janus-llm 2.1.0__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
janus/translate.py DELETED
@@ -1,987 +0,0 @@
1
- import json
2
- import math
3
- import re
4
- import time
5
- from copy import deepcopy
6
- from pathlib import Path
7
- from typing import Any
8
-
9
- from langchain.output_parsers import RetryWithErrorOutputParser
10
- from langchain.output_parsers.fix import OutputFixingParser
11
- from langchain_core.exceptions import OutputParserException
12
- from langchain_core.language_models import BaseLanguageModel
13
- from langchain_core.output_parsers import BaseOutputParser
14
- from langchain_core.prompts import ChatPromptTemplate
15
- from langchain_core.runnables import RunnableLambda, RunnableParallel
16
- from openai import BadRequestError, RateLimitError
17
- from text_generation.errors import ValidationError
18
-
19
- from .converter import Converter, run_if_changed
20
- from .embedding.vectorize import ChromaDBVectorizer
21
- from .language.block import CodeBlock, TranslatedCodeBlock
22
- from .language.combine import ChunkCombiner, Combiner, JsonCombiner
23
- from .language.naive.registry import CUSTOM_SPLITTERS
24
- from .language.splitter import EmptyTreeError, FileSizeError, TokenLimitError
25
- from .llm import load_model
26
- from .llm.model_callbacks import get_model_callback
27
- from .llm.models_info import MODEL_PROMPT_ENGINES
28
- from .parsers.code_parser import CodeParser, GenericParser
29
- from .parsers.doc_parser import MadlibsDocumentationParser, MultiDocumentationParser
30
- from .parsers.eval_parser import EvaluationParser
31
- from .parsers.reqs_parser import RequirementsParser
32
- from .prompts.prompt import (
33
- SAME_OUTPUT,
34
- TEXT_OUTPUT,
35
- retry_with_error_and_output_prompt,
36
- retry_with_output_prompt,
37
- )
38
- from .utils.enums import LANGUAGES
39
- from .utils.logger import create_logger
40
-
41
- log = create_logger(__name__)
42
-
43
-
44
- PARSER_TYPES: set[str] = {"code", "text", "eval", "madlibs", "multidoc", "requirements"}
45
-
46
-
47
- class Translator(Converter):
48
- """A class that translates code from one programming language to another."""
49
-
50
def __init__(
    self,
    model: str = "gpt-3.5-turbo-0125",
    model_arguments: dict[str, Any] | None = None,
    source_language: str = "fortran",
    target_language: str = "python",
    target_version: str | None = "3.10",
    max_prompts: int = 10,
    max_tokens: int | None = None,
    prompt_template: str | Path = "simple",
    parser_type: str = "code",
    db_path: str | None = None,
    db_config: dict[str, Any] | None = None,
    custom_splitter: str | None = None,
) -> None:
    """Initialize a Translator instance.

    Arguments:
        model: The LLM to use for translation. If an OpenAI model, the
            `OPENAI_API_KEY` environment variable must be set and the
            `OPENAI_ORG_ID` environment variable should be set if needed.
        model_arguments: Additional arguments to pass to the LLM constructor.
            `None` (the default) is treated as "no additional arguments".
        source_language: The source programming language.
        target_language: The target programming language.
        target_version: The target version of the target programming language.
        max_prompts: The maximum number of prompts to try before giving up.
        max_tokens: The maximum number of tokens the model will take in.
            If unspecified, a value derived from the model's own limit is used.
        prompt_template: name of prompt template directory
            (see janus/prompts/templates) or path to a directory.
        parser_type: The type of parser to use for parsing the LLM output.
            Must be one of `PARSER_TYPES` ("code" is the default).
        db_path: Path handed to the vectorizer factory; when `None`, no
            vectorizer is created.
        db_config: Configuration handed to the vectorizer factory.
        custom_splitter: Key into `CUSTOM_SPLITTERS` selecting a non-default
            splitter; when `None`, the parent class's splitter is used.
    """
    # Assigned before the parent initializer runs, as in the original code.
    # NOTE(review): presumably the parent initializer can trigger
    # `_load_splitter`, which reads this attribute - keep this ordering.
    self._custom_splitter = custom_splitter
    super().__init__(source_language=source_language)

    # Configuration values (assigned through the set_* methods below)
    self._parser_type: str | None
    self._model_name: str | None
    self._custom_model_arguments: dict[str, Any] | None
    self._target_language: str | None
    self._target_version: str | None
    self._target_glob: str | None
    self._prompt_template_name: str | None
    self._db_path: str | None
    self._db_config: dict[str, Any] | None

    # Objects derived from the configuration (built by the _load_* methods)
    self._llm: BaseLanguageModel | None
    self._parser: BaseOutputParser | None
    self._combiner: Combiner | None
    self._prompt: ChatPromptTemplate | None

    self.max_prompts = max_prompts
    # A user-supplied token limit takes precedence over the model default
    self.override_token_limit = max_tokens is not None
    self._max_tokens = max_tokens

    self.set_model(model_name=model, **(model_arguments or {}))
    self.set_parser_type(parser_type=parser_type)
    self.set_prompt(prompt_template=prompt_template)
    self.set_target_language(
        target_language=target_language,
        target_version=target_version,
    )
    self.set_db_path(db_path=db_path)
    self.set_db_config(db_config=db_config)

    self._load_parameters()
116
-
117
def _load_parameters(self) -> None:
    """Run every loader of this class, then defer to the parent class.

    The loaders are invoked in a fixed order: model, prompt, parser,
    combiner, vectorizer.
    """
    loaders = (
        self._load_model,
        self._load_prompt,
        self._load_parser,
        self._load_combiner,
        self._load_vectorizer,
    )
    for load in loaders:
        load()
    # The parent implementation will call self._changed_attrs.clear()
    super()._load_parameters()
124
-
125
def translate(
    self,
    input_directory: str | Path,
    output_directory: str | Path | None = None,
    overwrite: bool = False,
    collection_name: str | None = None,
) -> None:
    """Translate code in the input directory from the source language to the target
    language, and write the resulting files to the output directory.

    Files that cannot be translated (rate limit, unparseable model output,
    malformed input, token/size limits, empty trees) are logged and skipped
    rather than aborting the whole run.

    Arguments:
        input_directory: The directory containing the code to translate.
        output_directory: The directory to write the translated code to. If
            `None`, files are translated but nothing is written to disk.
        overwrite: Whether to overwrite existing files (vs skip them)
        collection_name: Collection to add the translated blocks to. Ignored
            (with a warning) when no vectorizer is configured.

    Raises:
        BadRequestError: If the request is rejected for a reason other than
            a detected error in the prompt.
        ValidationError: If raised while the token limit was not manually set.
    """
    # Convert paths to pathlib Paths if needed
    if isinstance(input_directory, str):
        input_directory = Path(input_directory)
    if isinstance(output_directory, str):
        output_directory = Path(output_directory)

    # Make sure the output directory exists
    if output_directory is not None and not output_directory.exists():
        output_directory.mkdir(parents=True)

    source_suffix = LANGUAGES[self._source_language]["suffix"]
    target_suffix = LANGUAGES[self._target_language]["suffix"]

    input_paths = list(input_directory.rglob(self._source_glob))

    # Source files are found recursively, so the "other files" figure must be
    # computed recursively (and over files only) to be comparable.
    n_all_files = sum(1 for p in input_directory.rglob("*") if p.is_file())

    log.info(f"Input directory: {input_directory.absolute()}")
    log.info(
        f"{self._source_language.capitalize()} '*.{source_suffix}' files: "
        f"{len(input_paths)}"
    )
    log.info(
        "Other files (skipped): "
        f"{max(0, n_all_files - len(input_paths))}\n"
    )
    if output_directory is not None:
        output_paths = [
            output_directory
            / p.relative_to(input_directory).with_suffix(f".{target_suffix}")
            for p in input_paths
        ]
        in_out_pairs = list(zip(input_paths, output_paths))
        if not overwrite:
            n_files = len(in_out_pairs)
            in_out_pairs = [
                (inp, outp) for inp, outp in in_out_pairs if not outp.exists()
            ]
            log.info(
                f"Skipping {n_files - len(in_out_pairs)} existing "
                f"'*.{source_suffix}' files"
            )
    else:
        in_out_pairs = [(f, None) for f in input_paths]
    log.info(f"Translating {len(in_out_pairs)} '*.{source_suffix}' files")

    # Now, loop through every code block in every file and translate it with an LLM
    total_cost = 0.0
    for in_path, out_path in in_out_pairs:
        # Translate the file; recoverable failures skip only this file
        try:
            out_block = self.translate_file(in_path)
            total_cost += out_block.total_cost
        except RateLimitError:
            log.warning(f"Skipping {in_path.name}, rate limit reached")
            continue
        except OutputParserException as e:
            log.error(f"Skipping {in_path.name}, failed to parse output: {e}.")
            continue
        except BadRequestError as e:
            if str(e).startswith("Detected an error in the prompt"):
                log.warning("Malformed input, skipping")
                continue
            raise
        except ValidationError:
            # Only allow ValidationError to pass if token limit is manually set
            if self.override_token_limit:
                log.warning(
                    "Current file and manually set token "
                    "limit is too large for this model, skipping"
                )
                continue
            raise
        except TokenLimitError:
            log.warning("Ran into irreducible node too large for context, skipping")
            continue
        except EmptyTreeError:
            log.warning(
                f'Input file "{in_path.name}" has no nodes of interest, skipping'
            )
            continue
        except FileSizeError:
            log.warning("Current file is too large for basic splitter, skipping")
            continue

        # Don't attempt to write files for which translation failed
        if not out_block.translated:
            continue

        # TODO: embeddings of the *target* output are not stored; only the
        # translated block tree is handed to the vectorizer below.
        if collection_name is not None:
            if self._vectorizer is None:
                # No db_path was configured, so there is nothing to add to
                log.warning(
                    f'No vectorizer configured, cannot add "{in_path.name}" '
                    f'to collection "{collection_name}"'
                )
            else:
                self._vectorizer.add_nodes_recursively(
                    out_block,
                    collection_name,
                    in_path.name,
                )

        # Make sure the tree's code has been consolidated at the top level
        # before writing to file
        self._combiner.combine(out_block)
        if out_path is not None and (overwrite or not out_path.exists()):
            self._save_to_file(out_block, out_path)

    log.info(f"Total cost: ${total_cost:,.2f}")
261
-
262
def _split_file(self, file: Path) -> CodeBlock:
    """Split a file with the configured splitter and log the resulting tree.

    Arguments:
        file: Path of the file to split

    Returns:
        The root `CodeBlock` produced by `self._splitter.split`
    """
    label = file.name
    log.info(f"[{label}] Splitting file")

    tree = self._splitter.split(file)

    summary = (
        f"[{label}] File split into {tree.n_descendents:,} blocks, "
        f"tree of height {tree.height}"
    )
    log.info(summary)
    log.info(f"[{label}] Input CodeBlock Structure:\n{tree.tree_str()}")
    return tree
272
-
273
def translate_file(self, file: Path) -> TranslatedCodeBlock:
    """Translate a single file.

    Arguments:
        file: Input path to file

    Returns:
        A `TranslatedCodeBlock` object. This block does not have a path set, and its
        code is not guaranteed to be consolidated. `translate` consolidates it by
        calling `self._combiner.combine` on the block before saving.
    """
    # Pick up any configuration changed through the set_* methods
    self._load_parameters()
    filename = file.name

    input_block = self._split_file(file)
    t0 = time.time()
    output_block = self._iterative_translate(input_block)
    output_block.processing_time = time.time() - t0
    if output_block.translated:
        completeness = output_block.translation_completeness
        # The structure reported here is that of the translated tree; the
        # input tree has already been logged by `_split_file`.
        log.info(
            f"[{filename}] Translation complete\n"
            f"  {completeness:.2%} of input successfully translated\n"
            f"  Total cost: ${output_block.total_cost:,.2f}\n"
            f"  Total retries: {output_block.total_retries:,d}\n"
            f"  Output CodeBlock Structure:\n{output_block.tree_str()}\n"
        )

    else:
        log.error(
            f"[{filename}] Translation failed\n"
            f"  Total cost: ${output_block.total_cost:,.2f}\n"
            f"  Total retries: {output_block.total_retries:,d}\n"
        )
    return output_block
308
-
309
def outputting_requirements(self) -> bool:
    """Report whether the configured prompt template is "requirements"."""
    # Placeholder logic: the system is expected to be revised to produce
    # more than a single kind of output.
    template = self._prompt_template_name
    return template == "requirements"
314
-
315
def outputting_summary(self) -> bool:
    """Report whether the configured prompt template is "document"."""
    template = self._prompt_template_name
    return template == "document"
318
-
319
def outputting_pseudocode(self) -> bool:
    """Report whether the configured prompt template is "pseudocode"."""
    # Placeholder logic: the system is expected to be revised to produce
    # more than a single kind of output.
    template = self._prompt_template_name
    return template == "pseudocode"
324
-
325
def _iterative_translate(self, root: CodeBlock) -> TranslatedCodeBlock:
    """Translate a whole file, block by block, without recursion.

    Blocks are processed from an explicit stack. The children of a block are
    only queued once that block itself has been translated.

    Arguments:
        root: A root block representing the top-level block of a file

    Returns:
        The `TranslatedCodeBlock` wrapping `root`
    """
    result = TranslatedCodeBlock(root, self._target_language)
    step = 0.1
    reported = 0
    pending = [result]
    while pending:
        current = pending.pop()
        self._add_translation(current)

        if current.translated:
            pending.extend(current.children)

            # Log progress only when it has advanced by more than one step
            progress = result.translation_completeness
            if progress - reported > step:
                reported = int(progress / step) * step
                log.info(f"[{root.name}] progress: {progress:.2%}")
        # Otherwise: translating this block was unsuccessful, so its children
        # are not queued (they wouldn't show up in the final text anyway)

    return result
355
-
356
def _add_translation(self, block: TranslatedCodeBlock) -> None:
    """Fill in a not-yet-translated `TranslatedCodeBlock` in place.

    Only the code of this block (`block.original`) is translated, not its
    children. Text, processing time, cost, retry count and token count are
    stored on the block, and nothing is returned.

    Arguments:
        block: An empty `TranslatedCodeBlock`

    Raises:
        ValueError: If no model has been loaded.
    """
    if block.translated:
        return

    source_text = block.original.text
    if source_text is None:
        # Nothing to send to the model; mark the block as done
        block.translated = True
        return

    if self._llm is None:
        message = (
            "Model not configured correctly, cannot translate. Try setting the model"
        )
        log.error(message)
        raise ValueError(message)

    log.debug(f"[{block.name}] Translating...")
    log.debug(f"[{block.name}] Input text:\n{source_text}")

    # Track the cost of translating this block
    # TODO: If non-OpenAI models with prices are added, this will need
    #  to be updated.
    with get_model_callback() as callback:
        started = time.time()
        block.text = self._run_chain(block)
        block.processing_time = time.time() - started
        block.cost = callback.total_cost
        # The first successful request is not a retry
        block.retries = max(0, callback.successful_requests - 1)

    block.tokens = self._llm.get_num_tokens(block.text)
    block.translated = True

    log.debug(f"[{block.name}] Output code:\n{block.text}")
397
-
398
def _run_chain(self, block: TranslatedCodeBlock) -> str:
    """Run the model with three nested error fixing schemes.

    First, try to fix simple formatting errors by giving the model just
    the output and the parsing error. After a number of attempts, try
    giving the model the output, the parsing error, and the original
    input. Again check/retry this output to solve for formatting errors.
    If we still haven't succeeded after several attempts, the model may
    be getting thrown off by a bad initial output; start from scratch
    and try again.

    The number of tries for each layer of this scheme is roughly equal
    to the cube root of self.max_prompts, so the total calls to the
    LLM will be roughly as expected (up to sqrt(self.max_prompts) over).
    A non-positive `max_prompts` is treated as 1.

    Arguments:
        block: The block whose `original.text` is sent to the model

    Returns:
        The parsed model output

    Raises:
        OutputParserException: If no attempt produced parseable output.
    """
    self._parser.set_reference(block.original)

    # A budget below 1 would make n1 zero (division by zero below), or
    # produce a complex root for negative values; clamp it instead.
    budget = max(1, self.max_prompts)

    # Retries with just the format instructions, the output, and the error
    n1 = max(1, round(budget ** (1 / 3)))

    # Retries with the input, the output, and the error
    n2 = max(1, round((budget // n1) ** (1 / 2)))

    # Retries with just the input
    n3 = math.ceil(budget / (n1 * n2))

    fix_format = OutputFixingParser.from_llm(
        llm=self._llm,
        parser=self._parser,
        max_retries=n1,
        prompt=retry_with_output_prompt,
    )
    retry = RetryWithErrorOutputParser.from_llm(
        llm=self._llm,
        parser=fix_format,
        max_retries=n2,
        prompt=retry_with_error_and_output_prompt,
    )

    # The retry parser needs both the completion and the prompt that
    # produced it, so both are computed side by side.
    completion_chain = self._prompt | self._llm
    chain = RunnableParallel(
        completion=completion_chain, prompt_value=self._prompt
    ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))

    last_error = None
    for _ in range(n3):
        try:
            return chain.invoke({"SOURCE_CODE": block.original.text})
        except OutputParserException as e:
            # Start over from scratch on the next iteration
            last_error = e

    raise OutputParserException(
        f"Failed to parse after {n1*n2*n3} retries"
    ) from last_error
448
-
449
def _save_to_file(self, block: CodeBlock, out_path: Path) -> None:
    """Write a block's combined output to disk as UTF-8 text.

    Arguments:
        block: The `CodeBlock` to save to a file.
        out_path: Destination file; missing parent directories are created.
    """
    # TODO: can't use output fixer and this system for combining output
    parsed_text = self._parser.parse_combined_output(block.complete_text)
    destination_dir = out_path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    out_path.write_text(parsed_text, encoding="utf-8")
459
-
460
def set_model(self, model_name: str, **custom_arguments: Any) -> None:
    """Set the model name and its custom constructor arguments.

    No validation is performed here. The affected objects will not be
    updated until translate() is called.

    Arguments:
        model_name: The name of the model to use. Valid models are found in
            `janus.llm.models_info.MODEL_CONSTRUCTORS`.
        custom_arguments: Additional arguments to pass to the model constructor,
            stored as a dictionary keyed by argument name.
    """
    self._model_name = model_name
    self._custom_model_arguments = custom_arguments
472
-
473
def set_parser_type(self, parser_type: str) -> None:
    """Check the parser type against `PARSER_TYPES` and store it.

    The affected objects will not be updated until translate() is called.

    Arguments:
        parser_type: The type of parser to use for parsing the LLM output.
            Must be a member of `PARSER_TYPES`.

    Raises:
        ValueError: If `parser_type` is not a member of `PARSER_TYPES`.
    """
    if parser_type in PARSER_TYPES:
        self._parser_type = parser_type
        return

    raise ValueError(
        f'Unsupported parser type "{parser_type}". Valid types: {PARSER_TYPES}'
    )
488
-
489
- def set_prompt(self, prompt_template: str | Path) -> None:
490
- """Validate and set the prompt template name.
491
-
492
- The affected objects will not be updated until translate() is called.
493
-
494
- Arguments:
495
- prompt_template: name of prompt template directory
496
- (see janus/prompts/templates) or path to a directory.
497
- """
498
- self._prompt_template_name = prompt_template
499
-
500
def set_target_language(
    self, target_language: str, target_version: str | None
) -> None:
    """Validate and store the target language, its version and its file glob.

    The language name is lower-cased before it is looked up in `LANGUAGES`.
    The affected objects will not be updated until translate() is called.

    Arguments:
        target_language: The target programming language.
        target_version: The target version of the target programming language.

    Raises:
        ValueError: If the lower-cased language is not a key of `LANGUAGES`.
    """
    language_key = target_language.lower()
    if language_key not in LANGUAGES:
        raise ValueError(
            f"Invalid target language: {language_key}. "
            "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
        )

    suffix = LANGUAGES[language_key]["suffix"]
    self._target_glob = f"**/*.{suffix}"
    self._target_language = language_key
    self._target_version = target_version
520
-
521
- def set_db_path(self, db_path: str) -> None:
522
- self._db_path = db_path
523
-
524
- def set_db_config(self, db_config: dict[str, Any] | None) -> None:
525
- self._db_config = db_config
526
-
527
@run_if_changed("_model_name", "_custom_model_arguments")
def _load_model(self) -> None:
    """Load the model according to this instance's attributes.

    If the relevant fields have not been changed since the last time this method was
    called, nothing happens.

    Sets `self._llm` and `self.model_cost`, and, unless the user supplied
    a token limit, `self._max_tokens`.
    """
    # NOTE(review): `self._custom_model_arguments` is not forwarded to
    # `load_model`; only the model name is used here.

    # Load the model
    self._llm, token_limit, self.model_cost = load_model(self._model_name)

    # Set the max_tokens to less than half the model's limit to allow for enough
    # tokens at output
    # Only modify max_tokens if it is not specified by user
    if not self.override_token_limit:
        # Floor division by a float yields a float; keep the limit an integer
        self._max_tokens = int(token_limit // 2.5)
546
-
547
@run_if_changed("_parser_type", "_target_language")
def _load_parser(self) -> None:
    """Build the output parser matching the parser type and target language.

    If the relevant fields have not been changed since the last time this method was
    called, nothing happens.

    Raises:
        ValueError: If the target language is "text" but the parser type is
            not, if a JSON-producing parser type is combined with a target
            language other than "json", or if the parser type is unknown.
    """
    parser_type = self._parser_type
    target = self._target_language

    if target == "text" and parser_type != "text":
        raise ValueError(
            f"Target language ({self._target_language}) suggests target "
            f"parser should be 'text', but is '{self._parser_type}'"
        )

    json_parser_types = {"eval", "multidoc", "madlibs"}
    if parser_type in json_parser_types and target != "json":
        raise ValueError(
            f"Parser type ({self._parser_type}) suggests target language"
            f" should be 'json', but is '{self._target_language}'"
        )

    # One zero-argument factory per supported parser type; only the selected
    # factory is called.
    factories = {
        "code": lambda: CodeParser(language=self._target_language),
        "eval": lambda: EvaluationParser(),
        "multidoc": lambda: MultiDocumentationParser(),
        "madlibs": lambda: MadlibsDocumentationParser(),
        "text": lambda: GenericParser(),
        "requirements": lambda: RequirementsParser(),
    }
    factory = factories.get(parser_type)
    if factory is None:
        raise ValueError(
            f"Unsupported parser type: {self._parser_type}. Can be: "
            f"{PARSER_TYPES}"
        )
    self._parser = factory()
584
-
585
@run_if_changed(
    "_prompt_template_name",
    "_source_language",
    "_target_language",
    "_target_version",
    "_model_name",
)
def _load_prompt(self) -> None:
    """Build the prompt from the template, languages, version and model.

    If the relevant fields have not been changed since the last time this
    method was called, nothing happens.

    Raises:
        ValueError: If the template is in `SAME_OUTPUT` but source and target
            languages differ, or if the template is in `TEXT_OUTPUT` but the
            target language is not "text".
    """
    template = self._prompt_template_name
    source = self._source_language
    target = self._target_language

    if template in SAME_OUTPUT and target != source:
        raise ValueError(
            f"Prompt template ({self._prompt_template_name}) suggests "
            f"source and target languages should match, but do not "
            f"({self._source_language} != {self._target_language})"
        )
    if template in TEXT_OUTPUT and target != "text":
        raise ValueError(
            f"Prompt template ({self._prompt_template_name}) suggests target "
            f"language should be 'text', but is '{self._target_language}'"
        )

    # The prompt engine is looked up by model name
    engine_factory = MODEL_PROMPT_ENGINES[self._model_name]
    engine = engine_factory(
        source_language=self._source_language,
        target_language=self._target_language,
        target_version=self._target_version,
        prompt_template=self._prompt_template_name,
    )
    self._prompt = engine.prompt
618
-
619
@run_if_changed("_db_path")
def _load_vectorizer(self) -> None:
    """Create the vectorizer for the configured database path.

    `self._vectorizer` is set to `None` when no database path is configured.
    """
    if self._db_path is not None:
        factory = ChromaDBVectorizer()
        self._vectorizer = factory.create_vectorizer(self._db_path, self._db_config)
    else:
        self._vectorizer = None
628
-
629
@run_if_changed(
    "_source_language",
    "_max_tokens",
    "_llm",
    "_protected_node_types",
    "_prune_node_types",
)
def _load_splitter(self) -> None:
    """Create the splitter, honoring `self._custom_splitter` when it is set.

    Without a custom splitter, the parent class's implementation is used.
    Otherwise the splitter class is looked up in `CUSTOM_SPLITTERS`.
    """
    splitter_name = self._custom_splitter
    if splitter_name is None:
        super()._load_splitter()
        return

    options = {
        "max_tokens": self._max_tokens,
        "model": self._llm,
        "protected_node_types": self._protected_node_types,
        "prune_node_types": self._prune_node_types,
    }
    # TODO: This should be configurable
    if splitter_name == "tag":
        options["tag"] = "<ITMOD_ALC_SPLIT>"

    splitter_factory = CUSTOM_SPLITTERS[splitter_name]
    self._splitter = splitter_factory(language=self._source_language, **options)
652
-
653
@run_if_changed("_target_language", "_parser_type")
def _load_combiner(self) -> None:
    """Select the combiner from the parser type and target language.

    The "requirements" parser type takes precedence over a "json" target
    language; every other combination gets the plain `Combiner`.
    """
    if self._parser_type == "requirements":
        combiner = ChunkCombiner()
    elif self._target_language == "json":
        combiner = JsonCombiner()
    else:
        combiner = Combiner()
    self._combiner = combiner
661
-
662
-
663
class Documenter(Translator):
    """A `Translator` fixed to the "document" prompt with text output."""

    def __init__(
        self, source_language: str = "fortran", drop_comments: bool = True, **kwargs
    ):
        """Initialize the Documenter.

        Arguments:
            source_language: The source programming language.
            drop_comments: Whether to prune the source language's comment
                nodes (node type taken from `LANGUAGES`, default "comment").
            kwargs: Remaining `Translator` arguments. The target language,
                target version, prompt template and parser type are always
                overridden.
        """
        forced = {
            "source_language": source_language,
            "target_language": "text",
            "target_version": None,
            "prompt_template": "document",
            "parser_type": "text",
        }
        kwargs.update(forced)
        super().__init__(**kwargs)

        if not drop_comments:
            return

        language_info = LANGUAGES[source_language]
        comment_type = language_info.get("comment_node_type", "comment")
        self.set_prune_node_types([comment_type])
681
-
682
-
683
class MultiDocumenter(Documenter):
    """A `Documenter` switched to the "multidocument" prompt with JSON output."""

    def __init__(self, **kwargs):
        """Initialize as a `Documenter`, then override prompt, parser and target.

        Arguments:
            kwargs: Arguments forwarded to `Documenter`.
        """
        super().__init__(**kwargs)
        # These take effect the next time the parameters are loaded
        self.set_prompt("multidocument")
        self.set_parser_type("multidoc")
        json_target, no_version = "json", None
        self.set_target_language(json_target, no_version)
689
-
690
-
691
- class MadLibsDocumenter(Documenter):
692
def __init__(
    self,
    comments_per_request: int | None = None,
    **kwargs,
) -> None:
    """Initialize a documenter that uses the "document_madlibs" prompt.

    Comments are never dropped: `drop_comments` is forced to `False`.

    Arguments:
        comments_per_request: Maximum number of comment placeholders sent to
            the model in one request; `None` means no splitting.
        kwargs: Arguments forwarded to `Documenter`.
    """
    kwargs["drop_comments"] = False
    super().__init__(**kwargs)

    self.set_prompt("document_madlibs")
    self.set_parser_type("madlibs")
    self.set_target_language("json", None)

    self.comments_per_request = comments_per_request
703
-
704
def _add_translation(self, block: TranslatedCodeBlock):
    """Translate a block, splitting its comment placeholders into groups.

    When the block holds more comment placeholders than
    `self.comments_per_request`, the model is run once per group of
    placeholders (all other placeholders are stripped from the text) and the
    JSON outputs are merged. Otherwise the parent implementation is used.

    Arguments:
        block: An empty `TranslatedCodeBlock`
    """
    if block.translated:
        return

    if block.original.text is None:
        block.translated = True
        return

    group_size = self.comments_per_request
    if group_size is None:
        return super()._add_translation(block)

    comment_pattern = r"<(?:INLINE|BLOCK)_COMMENT \w{8}>"
    source_text = block.original.text
    comments = list(re.finditer(comment_pattern, source_text))

    if not comments:
        log.info(f"[{block.name}] Skipping commentless block")
        block.translated = True
        block.text = None
        block.complete = True
        return

    if len(comments) <= group_size:
        return super()._add_translation(block)

    group_starts = list(range(0, len(comments), group_size))
    log.debug(
        f"[{block.name}] Block contains more than {self.comments_per_request}"
        f" comments, splitting {len(comments)} comments into"
        f" {len(group_starts)} groups"
    )

    # Metadata is accumulated over all the per-group runs
    block.processing_time = 0
    block.cost = 0
    block.retries = 0
    merged = {}
    for first in group_starts:
        group = comments[first : first + group_size]
        start_idx = group[0].start()
        end_idx = group[-1].end()

        # Keep the placeholders of this group only; strip them from the
        # text before and after the group
        before = re.sub(comment_pattern, "", block.original.text[:start_idx])
        kept = block.original.text[start_idx:end_idx]
        after = re.sub(comment_pattern, "", block.original.text[end_idx:])

        # Build a new TranslatedBlock using the reduced text
        reduced_original = deepcopy(block.original)
        reduced_original.text = before + kept + after
        partial = TranslatedCodeBlock(reduced_original, self._target_language)

        # Run the LLM on the reduced text
        super()._add_translation(partial)

        block.retries += partial.retries
        block.cost += partial.cost
        block.processing_time += partial.processing_time

        # Merge this group's output into the combined object
        parsed = self._parser.parse(partial.text)
        merged.update(json.loads(parsed))

    # Point the parser back at the full original before the final parse
    self._parser.set_reference(block.original)
    block.text = self._parser.parse(json.dumps(merged))
    block.tokens = self._llm.get_num_tokens(block.text)
    block.translated = True
779
-
780
def _get_obj(
    self, block: TranslatedCodeBlock
) -> dict[str, int | float | dict[str, str]]:
    """Build the object that is saved for a block.

    Arguments:
        block: The translated block to summarize

    Returns:
        A dictionary with the block's total retries, total cost, processing
        time, and the comments decoded from its combined JSON output.
    """
    combined_text = self._parser.parse_combined_output(block.complete_text)
    return {
        "retries": block.total_retries,
        "cost": block.total_cost,
        "processing_time": block.processing_time,
        "comments": json.loads(combined_text),
    }
791
-
792
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
    """Write the object built by `_get_obj` to disk as indented JSON.

    Arguments:
        block: The `CodeBlock` to save to a file.
        out_path: Destination file; missing parent directories are created.
    """
    payload = json.dumps(self._get_obj(block), indent=2)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(payload, encoding="utf-8")
801
-
802
-
803
- class DiagramGenerator(Documenter):
804
- """DiagramGenerator
805
-
806
- A class that translates code from one programming language to a set of diagrams.
807
- """
808
-
809
def __init__(
    self,
    model: str = "gpt-3.5-turbo-0125",
    model_arguments: dict[str, Any] | None = None,
    source_language: str = "fortran",
    max_prompts: int = 10,
    db_path: str | None = None,
    db_config: dict[str, Any] | None = None,
    diagram_type: str = "Activity",
    add_documentation: bool = False,
    custom_splitter: str | None = None,
) -> None:
    """Initialize the DiagramGenerator class

    Arguments:
        model: The LLM to use for translation. If an OpenAI model, the
            `OPENAI_API_KEY` environment variable must be set and the
            `OPENAI_ORG_ID` environment variable should be set if needed.
        model_arguments: Additional arguments to pass to the LLM constructor.
            `None` (the default) is treated as an empty dictionary.
        source_language: The source programming language.
        max_prompts: The maximum number of prompts to try before giving up.
        db_path: path to chroma database
        db_config: database configuration
        diagram_type: type of PLANTUML diagram to generate
        add_documentation: Whether to first generate documentation for each
            block and pass it to the diagram prompt (selects the
            "diagram_with_documentation" template instead of "diagram").
        custom_splitter: Name of a custom splitter, forwarded to the parent
            initializer.
    """
    # Use a fresh dictionary per instance rather than a shared default
    if model_arguments is None:
        model_arguments = {}

    super().__init__(
        model=model,
        model_arguments=model_arguments,
        source_language=source_language,
        max_prompts=max_prompts,
        db_path=db_path,
        db_config=db_config,
        custom_splitter=custom_splitter,
    )
    self._diagram_type = diagram_type
    self._add_documentation = add_documentation
    self._documenter = None
    self._model = model
    self._model_arguments = model_arguments
    self._max_prompts = max_prompts
    if add_documentation:
        self._diagram_prompt_template_name = "diagram_with_documentation"
    else:
        self._diagram_prompt_template_name = "diagram"
    self._load_diagram_prompt_engine()
854
-
855
- def _add_translation(self, block: TranslatedCodeBlock) -> None:
856
- """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
857
- `block.original`, setting the relevant fields in the translated block. The
858
- `TranslatedCodeBlock` is updated in-pace, nothing is returned. Note that this
859
- translates *only* the code for this block, not its children.
860
-
861
- Arguments:
862
- block: An empty `TranslatedCodeBlock`
863
- """
864
- if block.translated:
865
- return
866
-
867
- if block.original.text is None:
868
- block.translated = True
869
- return
870
-
871
- if self._add_documentation:
872
- documentation_block = deepcopy(block)
873
- super()._add_translation(documentation_block)
874
- if not documentation_block.translated:
875
- message = "Error: unable to produce documentation for code block"
876
- log.message(message)
877
- raise ValueError(message)
878
- documentation = json.loads(documentation_block.text)["docstring"]
879
-
880
- if self._llm is None:
881
- message = (
882
- "Model not configured correctly, cannot translate. Try setting "
883
- "the model"
884
- )
885
- log.error(message)
886
- raise ValueError(message)
887
-
888
- log.debug(f"[{block.name}] Translating...")
889
- log.debug(f"[{block.name}] Input text:\n{block.original.text}")
890
-
891
- self._parser.set_reference(block.original)
892
-
893
- query_and_parse = self.diagram_prompt | self._llm | self._parser
894
-
895
- if self._add_documentation:
896
- block.text = query_and_parse.invoke(
897
- {
898
- "SOURCE_CODE": block.original.text,
899
- "DIAGRAM_TYPE": self._diagram_type,
900
- "DOCUMENTATION": documentation,
901
- }
902
- )
903
- else:
904
- block.text = query_and_parse.invoke(
905
- {
906
- "SOURCE_CODE": block.original.text,
907
- "DIAGRAM_TYPE": self._diagram_type,
908
- }
909
- )
910
- block.tokens = self._llm.get_num_tokens(block.text)
911
- block.translated = True
912
-
913
- log.debug(f"[{block.name}] Output code:\n{block.text}")
914
-
915
- @run_if_changed(
916
- "_diagram_prompt_template_name",
917
- "_source_language",
918
- )
919
- def _load_diagram_prompt_engine(self) -> None:
920
- """Load the prompt engine according to this instance's attributes.
921
-
922
- If the relevant fields have not been changed since the last time this method was
923
- called, nothing happens.
924
- """
925
- if self._diagram_prompt_template_name in SAME_OUTPUT:
926
- if self._target_language != self._source_language:
927
- raise ValueError(
928
- f"Prompt template ({self._prompt_template_name}) suggests "
929
- f"source and target languages should match, but do not "
930
- f"({self._source_language} != {self._target_language})"
931
- )
932
- if (
933
- self._diagram_prompt_template_name in TEXT_OUTPUT
934
- and self._target_language != "text"
935
- ):
936
- raise ValueError(
937
- f"Prompt template ({self._prompt_template_name}) suggests target "
938
- f"language should be 'text', but is '{self._target_language}'"
939
- )
940
-
941
- self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
942
- source_language=self._source_language,
943
- target_language=self._target_language,
944
- target_version=self._target_version,
945
- prompt_template=self._diagram_prompt_template_name,
946
- )
947
- self.diagram_prompt = self._diagram_prompt_engine.prompt
948
-
949
-
950
class RequirementsDocumenter(Documenter):
    """RequirementsDocumenter

    A `Documenter` configured with the "requirements" prompt, a "json" target
    language and the "requirements" parser type, which saves its results as a
    JSON file of per-chunk requirements.
    """

    def __init__(self, **kwargs):
        """Initialize the parent class, then apply the requirements settings.

        Arguments:
            **kwargs: Passed through unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self.set_prompt("requirements")
        self.set_target_language("json", None)
        self.set_parser_type("requirements")

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Write the requirements of every child of `block` to a JSON file.

        The file holds a single top-level "output" key containing one entry per
        child, each with its generation "metadata", original "code" and parsed
        "requirements".

        Arguments:
            block: The `TranslatedCodeBlock` whose children are saved.
            out_path: Destination file; missing parent directories are created.
        """
        entries = []
        for chunk in block.children:
            # Original source text of this chunk.
            source_text = chunk.original.text
            # Requirements parsed from the chunk's combined LLM output.
            parsed_requirements = self._parser.parse_combined_output(
                chunk.complete_text
            )
            # Bookkeeping about how this chunk was generated.
            generation_info = {
                "retries": chunk.total_retries,
                "cost": chunk.total_cost,
                "processing_time": chunk.processing_time,
            }
            entries.append(
                {
                    "metadata": generation_info,
                    "code": source_text,
                    "requirements": parsed_requirements,
                }
            )

        payload = {"output": entries}
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")