janus-llm 2.1.0__py3-none-any.whl → 3.0.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
janus/translate.py DELETED
@@ -1,987 +0,0 @@
1
- import json
2
- import math
3
- import re
4
- import time
5
- from copy import deepcopy
6
- from pathlib import Path
7
- from typing import Any
8
-
9
- from langchain.output_parsers import RetryWithErrorOutputParser
10
- from langchain.output_parsers.fix import OutputFixingParser
11
- from langchain_core.exceptions import OutputParserException
12
- from langchain_core.language_models import BaseLanguageModel
13
- from langchain_core.output_parsers import BaseOutputParser
14
- from langchain_core.prompts import ChatPromptTemplate
15
- from langchain_core.runnables import RunnableLambda, RunnableParallel
16
- from openai import BadRequestError, RateLimitError
17
- from text_generation.errors import ValidationError
18
-
19
- from .converter import Converter, run_if_changed
20
- from .embedding.vectorize import ChromaDBVectorizer
21
- from .language.block import CodeBlock, TranslatedCodeBlock
22
- from .language.combine import ChunkCombiner, Combiner, JsonCombiner
23
- from .language.naive.registry import CUSTOM_SPLITTERS
24
- from .language.splitter import EmptyTreeError, FileSizeError, TokenLimitError
25
- from .llm import load_model
26
- from .llm.model_callbacks import get_model_callback
27
- from .llm.models_info import MODEL_PROMPT_ENGINES
28
- from .parsers.code_parser import CodeParser, GenericParser
29
- from .parsers.doc_parser import MadlibsDocumentationParser, MultiDocumentationParser
30
- from .parsers.eval_parser import EvaluationParser
31
- from .parsers.reqs_parser import RequirementsParser
32
- from .prompts.prompt import (
33
- SAME_OUTPUT,
34
- TEXT_OUTPUT,
35
- retry_with_error_and_output_prompt,
36
- retry_with_output_prompt,
37
- )
38
- from .utils.enums import LANGUAGES
39
- from .utils.logger import create_logger
40
-
41
- log = create_logger(__name__)
42
-
43
-
44
- PARSER_TYPES: set[str] = {"code", "text", "eval", "madlibs", "multidoc", "requirements"}
45
-
46
-
47
- class Translator(Converter):
48
- """A class that translates code from one programming language to another."""
49
-
50
def __init__(
    self,
    model: str = "gpt-3.5-turbo-0125",
    model_arguments: dict[str, Any] | None = None,
    source_language: str = "fortran",
    target_language: str = "python",
    target_version: str | None = "3.10",
    max_prompts: int = 10,
    max_tokens: int | None = None,
    prompt_template: str | Path = "simple",
    parser_type: str = "code",
    db_path: str | None = None,
    db_config: dict[str, Any] | None = None,
    custom_splitter: str | None = None,
) -> None:
    """Initialize a Translator instance.

    Arguments:
        model: The LLM to use for translation. If an OpenAI model, the
            `OPENAI_API_KEY` environment variable must be set and the
            `OPENAI_ORG_ID` environment variable should be set if needed.
        model_arguments: Additional arguments to pass to the LLM constructor.
            `None` (the default) means no additional arguments.
        source_language: The source programming language.
        target_language: The target programming language.
        target_version: The target version of the target programming language.
        max_prompts: The maximum number of prompts to try before giving up.
        max_tokens: The maximum number of tokens the model will take in.
            If unspecified, the model's default max will be used.
        prompt_template: name of prompt template directory
            (see janus/prompts/templates) or path to a directory.
        parser_type: The type of parser to use for parsing the LLM output.
            Must be one of the values in `PARSER_TYPES`; defaults to "code".
        db_path: Path handed to the vectorizer factory. If `None`, no
            vectorizer is created.
        db_config: Configuration handed to the vectorizer factory.
        custom_splitter: Key into `CUSTOM_SPLITTERS` selecting a non-default
            splitter. If `None`, the parent class's splitter is used.
    """
    # Assigned before the parent constructor runs, as in the original code.
    # NOTE(review): presumably the parent constructor loads the splitter,
    # which reads this attribute -- confirm against Converter.
    self._custom_splitter = custom_splitter
    super().__init__(source_language=source_language)

    self._parser_type: str | None
    self._model_name: str | None
    self._custom_model_arguments: dict[str, Any] | None
    self._target_language: str | None
    self._target_version: str | None
    self._target_glob: str | None
    self._prompt_template_name: str | None
    self._db_path: str | None
    self._db_config: dict[str, Any] | None

    self._llm: BaseLanguageModel | None
    self._parser: BaseOutputParser | None
    self._combiner: Combiner | None
    self._prompt: ChatPromptTemplate | None

    self.max_prompts = max_prompts
    # A user-supplied token limit takes precedence over the model's own
    self.override_token_limit = max_tokens is not None
    self._max_tokens = max_tokens

    # A fresh dict per call: a `{}` default in the signature would be one
    # object shared between every instance
    self.set_model(model_name=model, **(model_arguments or {}))
    self.set_parser_type(parser_type=parser_type)
    self.set_prompt(prompt_template=prompt_template)
    self.set_target_language(
        target_language=target_language,
        target_version=target_version,
    )
    self.set_db_path(db_path=db_path)
    self.set_db_config(db_config=db_config)

    self._load_parameters()
116
-
117
def _load_parameters(self) -> None:
    """Run every loader of this class, then the parent's `_load_parameters`.

    The loaders are called in a fixed order: model, prompt, parser, combiner,
    vectorizer.
    """
    loaders = (
        self._load_model,
        self._load_prompt,
        self._load_parser,
        self._load_combiner,
        self._load_vectorizer,
    )
    for load in loaders:
        load()
    super()._load_parameters()  # will call self._changed_attrs.clear()
124
-
125
def translate(
    self,
    input_directory: str | Path,
    output_directory: str | Path | None = None,
    overwrite: bool = False,
    collection_name: str | None = None,
) -> None:
    """Translate code in the input directory from the source language to the target
    language, and write the resulting files to the output directory.

    Source files are found recursively with the source-language glob. Each
    file is translated with `translate_file`; files that hit a known,
    recoverable error are logged and skipped.

    Arguments:
        input_directory: The directory containing the code to translate.
        output_directory: The directory to write the translated code to. If
            `None`, files are translated but nothing is written to disk.
        overwrite: Whether to overwrite existing files (vs skip them)
        collection_name: Collection to add the translated blocks to. Requires
            a vectorizer, i.e. a database path must have been set.

    Raises:
        ValueError: If `collection_name` is given but no vectorizer is loaded.
        BadRequestError: If the request is rejected for a reason other than a
            detected error in the prompt.
        ValidationError: If validation fails and the token limit was not set
            manually.
    """
    # Path() accepts both strings and existing Path objects
    input_directory = Path(input_directory)
    if output_directory is not None:
        output_directory = Path(output_directory)
        # exist_ok avoids a race between an existence check and the creation
        output_directory.mkdir(parents=True, exist_ok=True)

    # Apply pending setter changes now so that a misconfiguration is reported
    # before any LLM call is paid for
    self._load_parameters()
    if collection_name is not None and self._vectorizer is None:
        message = (
            f'Cannot add to collection "{collection_name}": no vectorizer is '
            "loaded. Set a database path first."
        )
        log.error(message)
        raise ValueError(message)

    source_suffix = LANGUAGES[self._source_language]["suffix"]
    target_suffix = LANGUAGES[self._target_language]["suffix"]

    input_paths = list(input_directory.rglob(self._source_glob))
    # Count files recursively so the number is comparable with the recursive
    # glob above (a non-recursive listing could make the difference negative)
    n_all_files = sum(1 for p in input_directory.rglob("*") if p.is_file())

    log.info(f"Input directory: {input_directory.absolute()}")
    log.info(
        f"{self._source_language.capitalize()} '*.{source_suffix}' files: "
        f"{len(input_paths)}"
    )
    log.info("Other files (skipped): " f"{n_all_files - len(input_paths)}\n")

    if output_directory is None:
        in_out_pairs = [(p, None) for p in input_paths]
    else:
        in_out_pairs = [
            (
                p,
                output_directory
                / p.relative_to(input_directory).with_suffix(f".{target_suffix}"),
            )
            for p in input_paths
        ]
        if not overwrite:
            n_files = len(in_out_pairs)
            in_out_pairs = [
                (inp, outp) for inp, outp in in_out_pairs if not outp.exists()
            ]
            log.info(
                f"Skipping {n_files - len(in_out_pairs)} existing "
                f"'*.{source_suffix}' files"
            )
    log.info(f"Translating {len(in_out_pairs)} '*.{source_suffix}' files")

    # Now, loop through every code block in every file and translate it with an LLM
    total_cost = 0.0
    for in_path, out_path in in_out_pairs:
        # Translate the file; recoverable failures skip this file only
        try:
            out_block = self.translate_file(in_path)
        except RateLimitError:
            log.warning(f"Skipping {in_path.name}, rate limit error")
            continue
        except OutputParserException as e:
            log.error(f"Skipping {in_path.name}, failed to parse output: {e}.")
            continue
        except BadRequestError as e:
            if str(e).startswith("Detected an error in the prompt"):
                log.warning("Malformed input, skipping")
                continue
            raise
        except ValidationError:
            # Only allow ValidationError to pass if token limit is manually set
            if self.override_token_limit:
                log.warning(
                    "Current file and manually set token "
                    "limit is too large for this model, skipping"
                )
                continue
            raise
        except TokenLimitError:
            log.warning("Ran into irreducible node too large for context, skipping")
            continue
        except EmptyTreeError:
            log.warning(
                f'Input file "{in_path.name}" has no nodes of interest, skipping'
            )
            continue
        except FileSizeError:
            log.warning("Current file is too large for basic splitter, skipping")
            continue

        # Cost is counted even when the translation itself failed
        total_cost += out_block.total_cost

        # Don't attempt to write files for which translation failed
        if not out_block.translated:
            continue

        # TODO: optionally embed the translated output as well (requirements,
        # summary, pseudocode or target embeddings)
        if collection_name is not None:
            self._vectorizer.add_nodes_recursively(
                out_block,
                collection_name,
                in_path.name,
            )

        # Make sure the tree's code has been consolidated at the top level
        # before writing to file
        self._combiner.combine(out_block)
        if out_path is not None and (overwrite or not out_path.exists()):
            self._save_to_file(out_block, out_path)

    log.info(f"Total cost: ${total_cost:,.2f}")
261
-
262
def _split_file(self, file: Path) -> CodeBlock:
    """Split `file` with the configured splitter and log the resulting tree.

    Arguments:
        file: Path of the file to split.

    Returns:
        The root block returned by the splitter.
    """
    tag = f"[{file.name}]"
    log.info(f"{tag} Splitting file")
    tree = self._splitter.split(file)
    log.info(
        f"{tag} File split into {tree.n_descendents:,} blocks, "
        f"tree of height {tree.height}"
    )
    log.info(f"{tag} Input CodeBlock Structure:\n{tree.tree_str()}")
    return tree
272
-
273
def translate_file(self, file: Path) -> TranslatedCodeBlock:
    """Translate a single file.

    Pending parameter changes are loaded first, then the file is split and
    the resulting block tree is translated. The outcome is logged either way.

    Arguments:
        file: Input path to file

    Returns:
        A `TranslatedCodeBlock` object. This block does not have a path set, and its
        code is not guaranteed to be consolidated. To amend this, pass the block
        to the combiner's `combine` method, as `translate` does.
    """
    self._load_parameters()
    filename = file.name

    input_block = self._split_file(file)
    t0 = time.time()
    output_block = self._iterative_translate(input_block)
    output_block.processing_time = time.time() - t0
    if output_block.translated:
        completeness = output_block.translation_completeness
        # Show the translated tree here; the input tree was already logged
        # when the file was split
        log.info(
            f"[{filename}] Translation complete\n"
            f"  {completeness:.2%} of input successfully translated\n"
            f"  Total cost: ${output_block.total_cost:,.2f}\n"
            f"  Total retries: {output_block.total_retries:,d}\n"
            f"  Output CodeBlock Structure:\n{output_block.tree_str()}\n"
        )
    else:
        log.error(
            f"[{filename}] Translation failed\n"
            f"  Total cost: ${output_block.total_cost:,.2f}\n"
            f"  Total retries: {output_block.total_retries:,d}\n"
        )
    return output_block
308
-
309
def outputting_requirements(self) -> bool:
    """Report whether the configured prompt template is "requirements"."""
    # Placeholder logic: the system is expected to be revised to produce
    # more than a single kind of output
    template = self._prompt_template_name
    return template == "requirements"
314
-
315
def outputting_summary(self) -> bool:
    """Report whether the configured prompt template is "document"."""
    template = self._prompt_template_name
    return template == "document"
318
-
319
def outputting_pseudocode(self) -> bool:
    """Report whether the configured prompt template is "pseudocode"."""
    # Placeholder logic: the system is expected to be revised to produce
    # more than a single kind of output
    template = self._prompt_template_name
    return template == "pseudocode"
324
-
325
def _iterative_translate(self, root: CodeBlock) -> TranslatedCodeBlock:
    """Translate a whole file, given as its root block, one block at a time.

    Blocks are taken from a stack. The children of a block are only pushed
    once the block itself has been translated.

    Arguments:
        root: A root block representing the top-level block of a file

    Returns:
        The `TranslatedCodeBlock` built from `root`
    """
    step = 0.1
    reported = 0
    result = TranslatedCodeBlock(root, self._target_language)
    pending = [result]
    while pending:
        current = pending.pop()
        self._add_translation(current)

        # Children of a block that failed to translate are never visited
        # (they wouldn't show up in the final text anyway)
        if current.translated:
            pending.extend(current.children)

            completeness = result.translation_completeness
            if completeness - reported > step:
                reported = int(completeness / step) * step
                log.info(f"[{root.name}] progress: {completeness:.2%}")

    return result
355
-
356
def _add_translation(self, block: TranslatedCodeBlock) -> None:
    """Fill in an untranslated `TranslatedCodeBlock` from `block.original`.

    The block is modified in place and nothing is returned. Only the code of
    this block is translated; its children are left untouched.

    Arguments:
        block: An empty `TranslatedCodeBlock`

    Raises:
        ValueError: If no model is loaded.
    """
    # Already handled
    if block.translated:
        return

    # No text to translate: mark as done
    source_text = block.original.text
    if source_text is None:
        block.translated = True
        return

    if self._llm is None:
        message = (
            "Model not configured correctly, cannot translate. "
            "Try setting the model"
        )
        log.error(message)
        raise ValueError(message)

    log.debug(f"[{block.name}] Translating...")
    log.debug(f"[{block.name}] Input text:\n{source_text}")

    # Track the cost of translating this block
    # TODO: If non-OpenAI models with prices are added, this will need
    #  to be updated.
    with get_model_callback() as cb:
        started = time.time()
        block.text = self._run_chain(block)
        block.processing_time = time.time() - started
        block.cost = cb.total_cost
        # Every successful request after the first counts as a retry
        block.retries = max(0, cb.successful_requests - 1)

    block.tokens = self._llm.get_num_tokens(block.text)
    block.translated = True

    log.debug(f"[{block.name}] Output code:\n{block.text}")
397
-
398
def _run_chain(self, block: TranslatedCodeBlock) -> str:
    """Run the model with three nested error fixing schemes.
    First, try to fix simple formatting errors by giving the model just
    the output and the parsing error. After a number of attempts, try
    giving the model the output, the parsing error, and the original
    input. Again check/retry this output to solve for formatting errors.
    If we still haven't succeeded after several attempts, the model may
    be getting thrown off by a bad initial output; start from scratch
    and try again.

    The number of tries for each layer of this scheme is roughly equal
    to the cube root of self.max_prompts, so the total number of calls to
    the LLM stays roughly at self.max_prompts.

    Arguments:
        block: The block whose `original.text` is sent to the model.

    Returns:
        The parsed model output.

    Raises:
        OutputParserException: If no attempt produced parseable output.
    """
    self._parser.set_reference(block.original)

    # Retries with just the format instructions, the output, and the error.
    # Clamped to 1 so a max_prompts below 1 cannot cause a division by zero.
    n1 = max(1, round(self.max_prompts ** (1 / 3)))

    # Retries with the input, the output, and the error
    n2 = max(1, round((self.max_prompts // n1) ** (1 / 2)))

    # Retries with just the input
    n3 = math.ceil(self.max_prompts / (n1 * n2))

    fix_format = OutputFixingParser.from_llm(
        llm=self._llm,
        parser=self._parser,
        max_retries=n1,
        prompt=retry_with_output_prompt,
    )
    retry = RetryWithErrorOutputParser.from_llm(
        llm=self._llm,
        parser=fix_format,
        max_retries=n2,
        prompt=retry_with_error_and_output_prompt,
    )

    completion_chain = self._prompt | self._llm
    chain = RunnableParallel(
        completion=completion_chain, prompt_value=self._prompt
    ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))

    last_error = None
    for _ in range(n3):
        try:
            return chain.invoke({"SOURCE_CODE": block.original.text})
        except OutputParserException as e:
            # Remember the failure and start again from scratch
            last_error = e

    # Chain the last parsing failure so its details are not lost
    raise OutputParserException(
        f"Failed to parse after {n1*n2*n3} retries"
    ) from last_error
448
-
449
def _save_to_file(self, block: CodeBlock, out_path: Path) -> None:
    """Write the parsed, combined text of a block to disk as UTF-8.

    Arguments:
        block: The `CodeBlock` to save to a file.
        out_path: Destination file. Missing parent directories are created.
    """
    # TODO: can't use output fixer and this system for combining output
    parsed = self._parser.parse_combined_output(block.complete_text)
    target_dir = out_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    out_path.write_text(parsed, encoding="utf-8")
459
-
460
def set_model(self, model_name: str, **custom_arguments: Any) -> None:
    """Set the model name and the custom model arguments.

    The values are only stored; no validation is performed here.
    The affected objects will not be updated until translate() is called.

    Arguments:
        model_name: The name of the model to use. Valid models are found in
            `janus.llm.models_info.MODEL_CONSTRUCTORS`.
        custom_arguments: Additional arguments to pass to the model constructor.
    """
    self._model_name = model_name
    # Keyword arguments arrive as one dict; each value may be of any type
    self._custom_model_arguments = custom_arguments
472
-
473
def set_parser_type(self, parser_type: str) -> None:
    """Check the parser type against `PARSER_TYPES` and store it.

    The affected objects will not be updated until translate() is called.

    Arguments:
        parser_type: The type of parser to use for parsing the LLM output.
            Must be a member of `PARSER_TYPES`.

    Raises:
        ValueError: If `parser_type` is not a member of `PARSER_TYPES`.
    """
    if parser_type in PARSER_TYPES:
        self._parser_type = parser_type
        return

    raise ValueError(
        f'Unsupported parser type "{parser_type}". Valid types: ' f"{PARSER_TYPES}"
    )
488
-
489
def set_prompt(self, prompt_template: str | Path) -> None:
    """Store the prompt template to use.

    Nothing is loaded here; the affected objects will not be updated until
    translate() is called.

    Arguments:
        prompt_template: name of prompt template directory
            (see janus/prompts/templates) or path to a directory.
    """
    self._prompt_template_name = prompt_template
499
-
500
def set_target_language(
    self, target_language: str, target_version: str | None
) -> None:
    """Check the target language against `LANGUAGES` and store it.

    The language name is lower-cased before the lookup. The target glob is
    derived from the language's file suffix.
    The affected objects will not be updated until translate() is called.

    Arguments:
        target_language: The target programming language.
        target_version: The target version of the target programming language.

    Raises:
        ValueError: If the lower-cased language is not a key of `LANGUAGES`.
    """
    normalized = target_language.lower()
    if normalized not in LANGUAGES:
        raise ValueError(
            f"Invalid target language: {normalized}. "
            "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
        )

    suffix = LANGUAGES[normalized]["suffix"]
    self._target_glob = f"**/*.{suffix}"
    self._target_language = normalized
    self._target_version = target_version
520
-
521
def set_db_path(self, db_path: str | None) -> None:
    """Store the database path.

    The affected objects will not be updated until translate() is called.

    Arguments:
        db_path: Path handed to the vectorizer factory, or `None` for no
            vectorizer.
    """
    self._db_path = db_path
523
-
524
def set_db_config(self, db_config: dict[str, Any] | None) -> None:
    """Store the database configuration.

    Arguments:
        db_config: Configuration handed to the vectorizer factory, or `None`.
    """
    self._db_config = db_config
526
-
527
@run_if_changed("_model_name", "_custom_model_arguments")
def _load_model(self) -> None:
    """Load the model according to this instance's attributes.

    If the relevant fields have not been changed since the last time this method was
    called, nothing happens.
    """
    # TODO: the custom model arguments are stored by `set_model` but are not
    #  passed on to `load_model` here.

    # Load the model
    self._llm, token_limit, self.model_cost = load_model(self._model_name)
    # Use 40% of the model's limit for the input to allow for enough
    # tokens at output
    # Only modify max_tokens if it is not specified by user
    if not self.override_token_limit:
        # `//` with a float operand returns a float; a token count is an int
        self._max_tokens = int(token_limit // 2.5)
546
-
547
@run_if_changed("_parser_type", "_target_language")
def _load_parser(self) -> None:
    """Build the output parser matching the parser type.

    The parser type and target language are checked for consistency first.
    If the relevant fields have not been changed since the last time this method was
    called, nothing happens.

    Raises:
        ValueError: If parser type and target language do not fit together,
            or if the parser type is not supported.
    """
    kind = self._parser_type
    target = self._target_language

    if target == "text" and kind != "text":
        raise ValueError(
            f"Target language ({target}) suggests target "
            f"parser should be 'text', but is '{kind}'"
        )
    if kind in {"eval", "multidoc", "madlibs"} and target != "json":
        raise ValueError(
            f"Parser type ({kind}) suggests target language"
            f" should be 'json', but is '{target}'"
        )

    # One zero-argument factory per supported parser type
    factories = {
        "code": lambda: CodeParser(language=target),
        "eval": EvaluationParser,
        "multidoc": MultiDocumentationParser,
        "madlibs": MadlibsDocumentationParser,
        "text": GenericParser,
        "requirements": RequirementsParser,
    }
    factory = factories.get(kind)
    if factory is None:
        raise ValueError(
            f"Unsupported parser type: {kind}. Can be: " f"{PARSER_TYPES}"
        )
    self._parser = factory()
584
-
585
@run_if_changed(
    "_prompt_template_name",
    "_source_language",
    "_target_language",
    "_target_version",
    "_model_name",
)
def _load_prompt(self) -> None:
    """Build the prompt from the prompt engine registered for the model.

    The prompt template is checked against the languages first.
    If the relevant fields have not been changed since the last time this
    method was called, nothing happens.

    Raises:
        ValueError: If the template is in `SAME_OUTPUT` but source and target
            languages differ, or if it is in `TEXT_OUTPUT` but the target
            language is not "text".
    """
    template = self._prompt_template_name
    source = self._source_language
    target = self._target_language

    if template in SAME_OUTPUT and target != source:
        raise ValueError(
            f"Prompt template ({template}) suggests "
            f"source and target languages should match, but do not "
            f"({source} != {target})"
        )
    if template in TEXT_OUTPUT and target != "text":
        raise ValueError(
            f"Prompt template ({template}) suggests target "
            f"language should be 'text', but is '{target}'"
        )

    engine_factory = MODEL_PROMPT_ENGINES[self._model_name]
    engine = engine_factory(
        source_language=source,
        target_language=target,
        target_version=self._target_version,
        prompt_template=template,
    )
    self._prompt = engine.prompt
618
-
619
@run_if_changed("_db_path")
def _load_vectorizer(self) -> None:
    """Create the vectorizer for the database path, or clear it.

    Without a database path the vectorizer is set to `None`.
    """
    if self._db_path is not None:
        factory = ChromaDBVectorizer()
        self._vectorizer = factory.create_vectorizer(self._db_path, self._db_config)
    else:
        self._vectorizer = None
628
-
629
@run_if_changed(
    "_source_language",
    "_max_tokens",
    "_llm",
    "_protected_node_types",
    "_prune_node_types",
)
def _load_splitter(self) -> None:
    """Load the parent's splitter, or the custom one if one was requested.

    A custom splitter is looked up by name in `CUSTOM_SPLITTERS`.
    """
    name = self._custom_splitter
    if name is None:
        super()._load_splitter()
        return

    options = {
        "max_tokens": self._max_tokens,
        "model": self._llm,
        "protected_node_types": self._protected_node_types,
        "prune_node_types": self._prune_node_types,
    }
    # TODO: This should be configurable
    if name == "tag":
        options["tag"] = "<ITMOD_ALC_SPLIT>"

    splitter_factory = CUSTOM_SPLITTERS[name]
    self._splitter = splitter_factory(language=self._source_language, **options)
652
-
653
@run_if_changed("_target_language", "_parser_type")
def _load_combiner(self) -> None:
    """Pick the combiner class from parser type and target language.

    The "requirements" parser type takes precedence over a "json" target.
    """
    if self._parser_type == "requirements":
        combiner_factory = ChunkCombiner
    elif self._target_language == "json":
        combiner_factory = JsonCombiner
    else:
        combiner_factory = Combiner
    self._combiner = combiner_factory()
661
-
662
-
663
class Documenter(Translator):
    """A `Translator` preconfigured with the "document" prompt and text output."""

    def __init__(
        self, source_language: str = "fortran", drop_comments: bool = True, **kwargs
    ):
        """Initialize a Documenter instance.

        Arguments:
            source_language: The source programming language.
            drop_comments: Whether to prune comment nodes from the source.
            kwargs: Passed on to `Translator`. The target language, target
                version, prompt template and parser type are always
                overridden.
        """
        # The fixed settings come last so they win over caller-supplied values
        settings = {
            **kwargs,
            "source_language": source_language,
            "target_language": "text",
            "target_version": None,
            "prompt_template": "document",
            "parser_type": "text",
        }
        super().__init__(**settings)

        if not drop_comments:
            return

        # Fall back to "comment" when the language entry names no node type
        language_info = LANGUAGES[source_language]
        node_type = language_info.get("comment_node_type", "comment")
        self.set_prune_node_types([node_type])
681
-
682
-
683
class MultiDocumenter(Documenter):
    """A `Documenter` using the "multidocument" prompt with JSON output."""

    def __init__(self, **kwargs):
        """Initialize via `Documenter`, then switch prompt, parser and target.

        Arguments:
            kwargs: Passed on to `Documenter`.
        """
        super().__init__(**kwargs)
        # Replace the settings that Documenter fixed during construction
        self.set_prompt("multidocument")
        self.set_parser_type("multidoc")
        self.set_target_language("json", None)
689
-
690
-
691
- class MadLibsDocumenter(Documenter):
692
def __init__(
    self,
    comments_per_request: int | None = None,
    **kwargs,
) -> None:
    """Initialize a MadLibsDocumenter instance.

    Arguments:
        comments_per_request: Largest number of comment placeholders sent in
            one request. `None` means a block is never split up.
        kwargs: Passed on to `Documenter`; `drop_comments` is always set to
            `False`.
    """
    kwargs["drop_comments"] = False
    super().__init__(**kwargs)
    # Replace the settings that Documenter fixed during construction
    self.set_prompt("document_madlibs")
    self.set_parser_type("madlibs")
    self.set_target_language("json", None)
    self.comments_per_request = comments_per_request
703
-
704
def _add_translation(self, block: TranslatedCodeBlock):
    """Translate a block, splitting its comment placeholders into groups.

    If the block holds more placeholders than `comments_per_request`, one
    request is made per group of placeholders and the JSON results are
    merged. Otherwise the parent implementation handles the block as a
    whole. The block is modified in place.

    Arguments:
        block: An empty `TranslatedCodeBlock`
    """
    if block.translated:
        return

    text = block.original.text
    if text is None:
        block.translated = True
        return

    group_size = self.comments_per_request
    if group_size is None:
        return super()._add_translation(block)

    placeholder = re.compile(r"<(?:INLINE|BLOCK)_COMMENT \w{8}>")
    comments = list(placeholder.finditer(text))

    if not comments:
        log.info(f"[{block.name}] Skipping commentless block")
        block.translated = True
        block.text = None
        block.complete = True
        return

    if len(comments) <= group_size:
        return super()._add_translation(block)

    group_starts = list(range(0, len(comments), group_size))
    log.debug(
        f"[{block.name}] Block contains more than {group_size}"
        f" comments, splitting {len(comments)} comments into"
        f" {len(group_starts)} groups"
    )

    block.processing_time = 0
    block.cost = 0
    block.retries = 0
    merged = {}
    for first in group_starts:
        # Cut the text into the part before the group, the part spanning the
        # group's first to last placeholder, and the part after it
        group = comments[first : first + group_size]
        begin = group[0].start()
        end = group[-1].end()
        before = text[:begin]
        section = text[begin:end]
        after = text[end:]

        # Placeholders outside the section of interest are removed
        before = placeholder.sub("", before)
        after = placeholder.sub("", after)

        # Translate a copy of the original that carries the reduced text
        reduced = deepcopy(block.original)
        reduced.text = before + section + after
        partial = TranslatedCodeBlock(reduced, self._target_language)
        super()._add_translation(partial)

        # Accumulate the metadata over all requests
        block.retries += partial.retries
        block.cost += partial.cost
        block.processing_time += partial.processing_time

        # Merge this group's output into the overall result
        parsed = self._parser.parse(partial.text)
        merged.update(json.loads(parsed))

    self._parser.set_reference(block.original)
    block.text = self._parser.parse(json.dumps(merged))
    block.tokens = self._llm.get_num_tokens(block.text)
    block.translated = True
779
-
780
def _get_obj(
    self, block: TranslatedCodeBlock
) -> dict[str, int | float | dict[str, str]]:
    """Collect a block's metadata and parsed comments into one dictionary.

    Arguments:
        block: The block to summarize.

    Returns:
        A dictionary with the keys "retries", "cost", "processing_time" and
        "comments"; "comments" holds the decoded JSON of the parsed output.
    """
    parsed = self._parser.parse_combined_output(block.complete_text)
    return {
        "retries": block.total_retries,
        "cost": block.total_cost,
        "processing_time": block.processing_time,
        "comments": json.loads(parsed),
    }
791
-
792
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
    """Write the result of `_get_obj` for a block to disk as indented JSON.

    Arguments:
        block: The `CodeBlock` to save to a file.
        out_path: Destination file. Missing parent directories are created.
    """
    payload = json.dumps(self._get_obj(block), indent=2)
    target_dir = out_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    out_path.write_text(payload, encoding="utf-8")
801
-
802
-
803
- class DiagramGenerator(Documenter):
804
- """DiagramGenerator
805
-
806
- A class that translates code from one programming language to a set of diagrams.
807
- """
808
-
809
def __init__(
    self,
    model: str = "gpt-3.5-turbo-0125",
    model_arguments: dict[str, Any] | None = None,
    source_language: str = "fortran",
    max_prompts: int = 10,
    db_path: str | None = None,
    db_config: dict[str, Any] | None = None,
    diagram_type: str = "Activity",
    add_documentation: bool = False,
    custom_splitter: str | None = None,
) -> None:
    """Initialize the DiagramGenerator class

    Arguments:
        model: The LLM to use for translation. If an OpenAI model, the
            `OPENAI_API_KEY` environment variable must be set and the
            `OPENAI_ORG_ID` environment variable should be set if needed.
        model_arguments: Additional arguments to pass to the LLM constructor.
            `None` (the default) means no additional arguments.
        source_language: The source programming language.
        max_prompts: The maximum number of prompts to try before giving up.
        db_path: path to chroma database
        db_config: database configuration
        diagram_type: type of PLANTUML diagram to generate
        add_documentation: Whether to generate documentation for each block
            first and pass it to the diagram prompt.
        custom_splitter: Name of a custom splitter, passed on to the parent
            constructor.
    """
    # A fresh dict per call: a `{}` default in the signature would be one
    # object shared between every instance. Normalized before the parent
    # call because the parent unpacks this value with `**`.
    if model_arguments is None:
        model_arguments = {}

    super().__init__(
        model=model,
        model_arguments=model_arguments,
        source_language=source_language,
        max_prompts=max_prompts,
        db_path=db_path,
        db_config=db_config,
        custom_splitter=custom_splitter,
    )
    self._diagram_type = diagram_type
    self._add_documentation = add_documentation
    self._documenter = None
    self._model = model
    self._model_arguments = model_arguments
    self._max_prompts = max_prompts
    if add_documentation:
        self._diagram_prompt_template_name = "diagram_with_documentation"
    else:
        self._diagram_prompt_template_name = "diagram"
    self._load_diagram_prompt_engine()
854
-
855
def _add_translation(self, block: TranslatedCodeBlock) -> None:
    """Given an "empty" `TranslatedCodeBlock`, generate the diagram for the code
    represented in `block.original`, setting the relevant fields in the
    translated block. The `TranslatedCodeBlock` is updated in-place, nothing is
    returned. Note that this handles *only* the code for this block, not its
    children.

    If documentation was requested, a copy of the block is documented with the
    parent implementation first and the "docstring" entry of that result is
    passed to the diagram prompt.

    Arguments:
        block: An empty `TranslatedCodeBlock`

    Raises:
        ValueError: If the documentation could not be produced, or if no
            model is loaded.
    """
    if block.translated:
        return

    if block.original.text is None:
        block.translated = True
        return

    documentation = None
    if self._add_documentation:
        # Work on a copy so the documentation pass leaves `block` untouched
        documentation_block = deepcopy(block)
        super()._add_translation(documentation_block)
        if not documentation_block.translated:
            message = "Error: unable to produce documentation for code block"
            # Loggers have no `message` method; report at error level
            log.error(message)
            raise ValueError(message)
        documentation = json.loads(documentation_block.text)["docstring"]

    if self._llm is None:
        message = (
            "Model not configured correctly, cannot translate. Try setting "
            "the model"
        )
        log.error(message)
        raise ValueError(message)

    log.debug(f"[{block.name}] Translating...")
    log.debug(f"[{block.name}] Input text:\n{block.original.text}")

    self._parser.set_reference(block.original)

    query_and_parse = self.diagram_prompt | self._llm | self._parser

    # The documentation is only part of the prompt input when requested
    prompt_input = {
        "SOURCE_CODE": block.original.text,
        "DIAGRAM_TYPE": self._diagram_type,
    }
    if self._add_documentation:
        prompt_input["DOCUMENTATION"] = documentation

    block.text = query_and_parse.invoke(prompt_input)
    block.tokens = self._llm.get_num_tokens(block.text)
    block.translated = True

    log.debug(f"[{block.name}] Output code:\n{block.text}")
914
-
915
@run_if_changed(
    "_diagram_prompt_template_name",
    "_source_language",
)
def _load_diagram_prompt_engine(self) -> None:
    """Load the diagram prompt engine according to this instance's attributes.

    If the relevant fields have not been changed since the last time this method was
    called, nothing happens.

    On success, sets `self._diagram_prompt_engine` and `self.diagram_prompt`.

    Raises:
        ValueError: If the diagram prompt template is listed in `SAME_OUTPUT`
            but source and target languages differ, or if it is listed in
            `TEXT_OUTPUT` but the target language is not "text".
    """
    # Report the template that was actually validated (the diagram template),
    # not the parent's general `_prompt_template_name`.
    template_name = self._diagram_prompt_template_name

    if (
        template_name in SAME_OUTPUT
        and self._target_language != self._source_language
    ):
        raise ValueError(
            f"Prompt template ({template_name}) suggests "
            f"source and target languages should match, but do not "
            f"({self._source_language} != {self._target_language})"
        )
    if template_name in TEXT_OUTPUT and self._target_language != "text":
        raise ValueError(
            f"Prompt template ({template_name}) suggests target "
            f"language should be 'text', but is '{self._target_language}'"
        )

    self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
        source_language=self._source_language,
        target_language=self._target_language,
        target_version=self._target_version,
        prompt_template=self._diagram_prompt_template_name,
    )
    self.diagram_prompt = self._diagram_prompt_engine.prompt
948
-
949
-
950
class RequirementsDocumenter(Documenter):
    """RequirementsDocumenter

    A `Documenter` configured with the "requirements" prompt, a "json" target
    language and the "requirements" parser, which saves its results as JSON.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the parent, then apply the
        requirements-specific prompt, target language and parser type.
        """
        super().__init__(**kwargs)
        self.set_prompt("requirements")
        self.set_target_language("json", None)
        self.set_parser_type("requirements")

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Write the requirements for each child of `block` to a JSON file.

        The file holds a single top-level "output" list with one entry per
        child, each carrying "metadata", "code" and "requirements".

        Arguments:
            block: The `CodeBlock` whose children are saved.
            out_path: Destination file; missing parent directories are created.
        """
        entries = []
        for chunk in block.children:
            # Source text of the chunk and the requirements parsed from the
            # model's combined output for it.
            source_text = chunk.original.text
            parsed_requirements = self._parser.parse_combined_output(
                chunk.complete_text
            )
            # Generation statistics recorded alongside each chunk.
            stats = {
                "retries": chunk.total_retries,
                "cost": chunk.total_cost,
                "processing_time": chunk.processing_time,
            }
            entries.append(
                {
                    "metadata": stats,
                    "code": source_text,
                    "requirements": parsed_requirements,
                }
            )

        # Everything lives under a single top-level 'output' key.
        payload = {"output": entries}

        target_dir = out_path.parent
        target_dir.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")