janus-llm 2.0.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. janus/__init__.py +2 -2
  2. janus/__main__.py +1 -1
  3. janus/_tests/test_cli.py +1 -2
  4. janus/cli.py +43 -51
  5. janus/converter/__init__.py +6 -0
  6. janus/converter/_tests/__init__.py +0 -0
  7. janus/{_tests → converter/_tests}/test_translate.py +11 -22
  8. janus/converter/converter.py +614 -0
  9. janus/converter/diagram.py +124 -0
  10. janus/converter/document.py +131 -0
  11. janus/converter/evaluate.py +15 -0
  12. janus/converter/requirements.py +50 -0
  13. janus/converter/translate.py +108 -0
  14. janus/embedding/_tests/test_collections.py +2 -2
  15. janus/language/_tests/test_splitter.py +1 -1
  16. janus/language/alc/__init__.py +1 -0
  17. janus/language/alc/_tests/__init__.py +0 -0
  18. janus/language/alc/_tests/test_alc.py +28 -0
  19. janus/language/alc/alc.py +87 -0
  20. janus/language/block.py +4 -2
  21. janus/language/combine.py +0 -1
  22. janus/language/mumps/mumps.py +2 -3
  23. janus/language/naive/__init__.py +1 -1
  24. janus/language/naive/basic_splitter.py +4 -4
  25. janus/language/naive/chunk_splitter.py +4 -4
  26. janus/language/naive/registry.py +1 -1
  27. janus/language/naive/simple_ast.py +23 -12
  28. janus/language/naive/tag_splitter.py +4 -4
  29. janus/language/splitter.py +10 -4
  30. janus/language/treesitter/treesitter.py +26 -8
  31. janus/llm/model_callbacks.py +34 -37
  32. janus/llm/models_info.py +16 -3
  33. janus/metrics/_tests/test_llm.py +2 -3
  34. janus/metrics/_tests/test_rouge_score.py +1 -1
  35. janus/metrics/_tests/test_similarity_score.py +1 -1
  36. janus/metrics/complexity_metrics.py +3 -4
  37. janus/metrics/metric.py +3 -4
  38. janus/metrics/reading.py +27 -5
  39. janus/prompts/prompt.py +67 -7
  40. janus/utils/enums.py +6 -5
  41. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/METADATA +1 -1
  42. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/RECORD +45 -35
  43. janus/converter.py +0 -158
  44. janus/translate.py +0 -981
  45. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/LICENSE +0 -0
  46. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/WHEEL +0 -0
  47. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/entry_points.txt +0 -0
janus/translate.py DELETED
@@ -1,981 +0,0 @@
1
- import json
2
- import math
3
- import re
4
- import time
5
- from copy import deepcopy
6
- from pathlib import Path
7
- from typing import Any
8
-
9
- from langchain.output_parsers import RetryWithErrorOutputParser
10
- from langchain.output_parsers.fix import OutputFixingParser
11
- from langchain_core.exceptions import OutputParserException
12
- from langchain_core.language_models import BaseLanguageModel
13
- from langchain_core.output_parsers import BaseOutputParser
14
- from langchain_core.prompts import ChatPromptTemplate
15
- from langchain_core.runnables import RunnableLambda, RunnableParallel
16
- from openai import BadRequestError, RateLimitError
17
- from text_generation.errors import ValidationError
18
-
19
- from janus.language.naive.registry import CUSTOM_SPLITTERS
20
-
21
- from .converter import Converter, run_if_changed
22
- from .embedding.vectorize import ChromaDBVectorizer
23
- from .language.block import CodeBlock, TranslatedCodeBlock
24
- from .language.combine import ChunkCombiner, Combiner, JsonCombiner
25
- from .language.splitter import EmptyTreeError, FileSizeError, TokenLimitError
26
- from .llm import load_model
27
- from .llm.model_callbacks import get_model_callback
28
- from .llm.models_info import MODEL_PROMPT_ENGINES
29
- from .parsers.code_parser import CodeParser, GenericParser
30
- from .parsers.doc_parser import MadlibsDocumentationParser, MultiDocumentationParser
31
- from .parsers.eval_parser import EvaluationParser
32
- from .parsers.reqs_parser import RequirementsParser
33
- from .prompts.prompt import SAME_OUTPUT, TEXT_OUTPUT
34
- from .utils.enums import LANGUAGES
35
- from .utils.logger import create_logger
36
-
37
- log = create_logger(__name__)
38
-
39
-
40
- PARSER_TYPES: set[str] = {"code", "text", "eval", "madlibs", "multidoc", "requirements"}
41
-
42
-
43
- class Translator(Converter):
44
- """A class that translates code from one programming language to another."""
45
-
46
def __init__(
    self,
    model: str = "gpt-3.5-turbo-0125",
    model_arguments: dict[str, Any] | None = None,
    source_language: str = "fortran",
    target_language: str = "python",
    target_version: str | None = "3.10",
    max_prompts: int = 10,
    max_tokens: int | None = None,
    prompt_template: str | Path = "simple",
    parser_type: str = "code",
    db_path: str | None = None,
    db_config: dict[str, Any] | None = None,
    custom_splitter: str | None = None,
) -> None:
    """Initialize a Translator instance.

    Arguments:
        model: The LLM to use for translation. If an OpenAI model, the
            `OPENAI_API_KEY` environment variable must be set and the
            `OPENAI_ORG_ID` environment variable should be set if needed.
        model_arguments: Additional arguments to pass to the LLM constructor.
            `None` (the default) means no additional arguments.
        source_language: The source programming language.
        target_language: The target programming language.
        target_version: The target version of the target programming language.
        max_prompts: The maximum number of prompts to try before giving up.
        max_tokens: The maximum number of tokens the model will take in.
            If unspecified, a limit derived from the model's own token limit
            is used.
        prompt_template: name of prompt template directory
            (see janus/prompts/templates) or path to a directory.
        parser_type: The type of parser to use for parsing the LLM output.
            Must be one of the values in `PARSER_TYPES`.
        db_path: Path handed to the vectorizer factory. When `None`, no
            vectorizer is created.
        db_config: Configuration handed to the vectorizer factory together
            with `db_path`.
        custom_splitter: Key into `CUSTOM_SPLITTERS` selecting a splitter to
            use instead of the parent class's default. `None` keeps the default.
    """
    # Assigned before the parent constructor runs, as in the original code.
    # NOTE(review): presumably the parent constructor triggers _load_splitter,
    # which reads this attribute - confirm against Converter.
    self._custom_splitter = custom_splitter
    super().__init__(source_language=source_language)

    # Declarations only; the values are assigned by the set_* calls below.
    self._parser_type: str | None
    self._model_name: str | None
    self._custom_model_arguments: dict[str, Any] | None
    self._target_language: str | None
    self._target_version: str | None
    self._target_glob: str | None
    self._prompt_template_name: str | None
    self._db_path: str | None
    self._db_config: dict[str, Any] | None

    # Declarations only; the values are assigned by the _load_* methods.
    self._llm: BaseLanguageModel | None
    self._parser: BaseOutputParser | None
    self._combiner: Combiner | None
    self._prompt: ChatPromptTemplate | None

    self.max_prompts = max_prompts
    # An explicit max_tokens stops _load_model from replacing it with a value
    # derived from the model's own limit.
    self.override_token_limit = max_tokens is not None
    self._max_tokens = max_tokens

    # A fresh dict per call replaces the former shared mutable `{}` default.
    self.set_model(model_name=model, **(model_arguments or {}))
    self.set_parser_type(parser_type=parser_type)
    self.set_prompt(prompt_template=prompt_template)
    self.set_target_language(
        target_language=target_language,
        target_version=target_version,
    )
    self.set_db_path(db_path=db_path)
    self.set_db_config(db_config=db_config)

    self._load_parameters()
112
-
113
def _load_parameters(self) -> None:
    """Run every component loader in order, then the parent's loader.

    The order is model, prompt, parser, combiner, vectorizer; the parent
    class's `_load_parameters` runs last.
    """
    component_loaders = (
        self._load_model,
        self._load_prompt,
        self._load_parser,
        self._load_combiner,
        self._load_vectorizer,
    )
    for load_component in component_loaders:
        load_component()
    # Per the original inline note, this call clears self._changed_attrs.
    super()._load_parameters()
120
-
121
def translate(
    self,
    input_directory: str | Path,
    output_directory: str | Path | None = None,
    overwrite: bool = False,
    collection_name: str | None = None,
) -> None:
    """Translate code in the input directory from the source language to the target
    language, and write the resulting files to the output directory.

    Files are matched recursively with the source glob. A file whose translation
    raises one of the recoverable errors handled below is logged and skipped;
    the remaining files are still processed.

    Arguments:
        input_directory: The directory containing the code to translate.
        output_directory: The directory to write the translated code to. When
            `None`, files are translated but nothing is written.
        overwrite: Whether to overwrite existing files (vs skip them)
        collection_name: Collection to add translated blocks to. Ignored with a
            warning when no vectorizer is configured.

    Raises:
        BadRequestError: If the request is rejected for a reason other than a
            detected error in the prompt.
        ValidationError: If validation fails while the token limit was not
            manually overridden.
    """
    input_directory = Path(input_directory)
    if output_directory is not None:
        output_directory = Path(output_directory)
        # exist_ok avoids a race between an existence check and the creation.
        output_directory.mkdir(parents=True, exist_ok=True)

    source_suffix = LANGUAGES[self._source_language]["suffix"]
    target_suffix = LANGUAGES[self._target_language]["suffix"]

    input_paths = list(input_directory.rglob(self._source_glob))

    # Count the skipped files with the same recursive walk used to find the
    # inputs. The previous count subtracted the recursive match count from a
    # non-recursive iterdir() count (which also included directories), so it
    # could be wrong or even negative.
    matched_paths = set(input_paths)
    n_other_files = sum(
        1
        for path in input_directory.rglob("*")
        if path.is_file() and path not in matched_paths
    )

    log.info(f"Input directory: {input_directory.absolute()}")
    log.info(
        f"{self._source_language.capitalize()} '*.{source_suffix}' files: "
        f"{len(input_paths)}"
    )
    log.info(f"Other files (skipped): {n_other_files}\n")

    if output_directory is not None:
        output_paths = [
            output_directory
            / p.relative_to(input_directory).with_suffix(f".{target_suffix}")
            for p in input_paths
        ]
        in_out_pairs = list(zip(input_paths, output_paths))
        if not overwrite:
            n_files = len(in_out_pairs)
            in_out_pairs = [
                (inp, outp) for inp, outp in in_out_pairs if not outp.exists()
            ]
            log.info(
                f"Skipping {n_files - len(in_out_pairs)} existing "
                f"'*.{source_suffix}' files"
            )
    else:
        in_out_pairs = [(f, None) for f in input_paths]
    log.info(f"Translating {len(in_out_pairs)} '*.{source_suffix}' files")

    # Now, loop through every code block in every file and translate it with an LLM
    total_cost = 0.0
    for in_path, out_path in in_out_pairs:
        try:
            out_block = self.translate_file(in_path)
            total_cost += out_block.total_cost
        except RateLimitError:
            # Previously skipped silently; log it so the gap is visible.
            log.warning(f"Skipping {in_path.name}, rate limit reached")
            continue
        except OutputParserException as e:
            log.error(f"Skipping {in_path.name}, failed to parse output: {e}.")
            continue
        except BadRequestError as e:
            if str(e).startswith("Detected an error in the prompt"):
                log.warning("Malformed input, skipping")
                continue
            raise
        except ValidationError:
            # Only allow ValidationError to pass if token limit is manually set
            if self.override_token_limit:
                log.warning(
                    "Current file and manually set token "
                    "limit is too large for this model, skipping"
                )
                continue
            raise
        except TokenLimitError:
            log.warning("Ran into irreducible node too large for context, skipping")
            continue
        except EmptyTreeError:
            log.warning(
                f'Input file "{in_path.name}" has no nodes of interest, skipping'
            )
            continue
        except FileSizeError:
            log.warning("Current file is too large for basic splitter, skipping")
            continue

        # Don't attempt to write files for which translation failed
        if not out_block.translated:
            continue

        if collection_name is not None:
            if self._vectorizer is None:
                # _load_vectorizer leaves the vectorizer unset when db_path is
                # None; dereferencing it here used to raise AttributeError
                # after the file had already been translated.
                log.warning(
                    f"No vectorizer configured, not adding {in_path.name} "
                    f"to collection '{collection_name}'"
                )
            else:
                self._vectorizer.add_nodes_recursively(
                    out_block,
                    collection_name,
                    in_path.name,
                )

        # Make sure the tree's code has been consolidated at the top level
        # before writing to file
        self._combiner.combine(out_block)
        if out_path is not None and (overwrite or not out_path.exists()):
            self._save_to_file(out_block, out_path)

    log.info(f"Total cost: ${total_cost:,.2f}")
257
-
258
def _split_file(self, file: Path) -> CodeBlock:
    """Split `file` with the configured splitter and log the resulting tree.

    Arguments:
        file: Path of the file to split.

    Returns:
        The root block returned by `self._splitter.split`.
    """
    tag = file.name
    log.info(f"[{tag}] Splitting file")

    tree = self._splitter.split(file)

    shape_summary = (
        f"[{tag}] File split into {tree.n_descendents:,} blocks, "
        f"tree of height {tree.height}"
    )
    log.info(shape_summary)
    log.info(f"[{tag}] Input CodeBlock Structure:\n{tree.tree_str()}")
    return tree
268
-
269
def translate_file(self, file: Path) -> TranslatedCodeBlock:
    """Translate a single file.

    Reloads any changed parameters, splits the file into a block tree,
    translates the tree, and logs a summary of the outcome.

    Arguments:
        file: Input path to file

    Returns:
        A `TranslatedCodeBlock` object. Nothing is written to disk here, and the
        block's code is not consolidated by this method; `translate` runs
        `self._combiner.combine` on the returned block before saving it.
    """
    self._load_parameters()
    filename = file.name

    input_block = self._split_file(file)
    t0 = time.time()
    output_block = self._iterative_translate(input_block)
    output_block.processing_time = time.time() - t0
    if output_block.translated:
        completeness = output_block.translation_completeness
        # The structure logged under "Output CodeBlock Structure" is now the
        # translated tree; it previously re-printed the input tree, which
        # _split_file has already logged.
        log.info(
            f"[{filename}] Translation complete\n"
            f"  {completeness:.2%} of input successfully translated\n"
            f"  Total cost: ${output_block.total_cost:,.2f}\n"
            f"  Total retries: {output_block.total_retries:,d}\n"
            f"  Output CodeBlock Structure:\n{output_block.tree_str()}\n"
        )
    else:
        log.error(
            f"[{filename}] Translation failed\n"
            f"  Total cost: ${output_block.total_cost:,.2f}\n"
            f"  Total retries: {output_block.total_retries:,d}\n"
        )
    return output_block
304
-
305
def outputting_requirements(self) -> bool:
    """Return True when the prompt template name is "requirements"."""
    # Placeholder logic, per the original note: the system is expected to be
    # revised to produce more than a single output.
    template_name = self._prompt_template_name
    return template_name == "requirements"
310
-
311
def outputting_summary(self) -> bool:
    """Return True when the prompt template name is "document"."""
    template_name = self._prompt_template_name
    return template_name == "document"
314
-
315
def outputting_pseudocode(self) -> bool:
    """Return True when the prompt template name is "pseudocode"."""
    # Placeholder logic, per the original note: the system is expected to be
    # revised to produce more than a single output.
    template_name = self._prompt_template_name
    return template_name == "pseudocode"
320
-
321
def _iterative_translate(self, root: CodeBlock) -> TranslatedCodeBlock:
    """Translate a whole block tree, walking it with an explicit stack.

    Arguments:
        root: A root block representing the top-level block of a file

    Returns:
        The `TranslatedCodeBlock` built from `root`. Children of a block whose
        translation did not succeed are never visited.
    """
    result = TranslatedCodeBlock(root, self._target_language)
    report_step = 0.1
    last_reported = 0
    pending = [result]

    while pending:
        current = pending.pop()
        self._add_translation(current)

        # Only descend into blocks that translated successfully; the children
        # of a failed block wouldn't show up in the final text anyway.
        if current.translated:
            pending.extend(current.children)

            # Log each time completeness advances by more than one step.
            progress = result.translation_completeness
            if progress - last_reported > report_step:
                last_reported = int(progress / report_step) * report_step
                log.info(f"[{root.name}] progress: {progress:.2%}")

    return result
351
-
352
def _add_translation(self, block: TranslatedCodeBlock) -> None:
    """Translate the code held in `block.original` and record it on `block`.

    The block is updated in place and nothing is returned. Only this block's
    own text is translated, not its children. A block that is already
    translated is left untouched, and a block whose original text is `None`
    is simply marked as translated.

    Arguments:
        block: An empty `TranslatedCodeBlock`

    Raises:
        ValueError: If no model has been loaded.
    """
    if block.translated:
        return

    source_text = block.original.text
    if source_text is None:
        block.translated = True
        return

    if self._llm is None:
        message = (
            "Model not configured correctly, cannot translate. "
            "Try setting the model"
        )
        log.error(message)
        raise ValueError(message)

    log.debug(f"[{block.name}] Translating...")
    log.debug(f"[{block.name}] Input text:\n{source_text}")

    # Track the cost of translating this block
    # TODO: If non-OpenAI models with prices are added, this will need
    #  to be updated.
    with get_model_callback() as usage:
        started = time.time()
        block.text = self._run_chain(block)
        block.processing_time = time.time() - started
        block.cost = usage.total_cost
        # Every successful request beyond the first is counted as a retry.
        block.retries = max(0, usage.successful_requests - 1)

    block.tokens = self._llm.get_num_tokens(block.text)
    block.translated = True

    log.debug(f"[{block.name}] Output code:\n{block.text}")
393
-
394
def _run_chain(self, block: TranslatedCodeBlock) -> str:
    """Run the model with three nested error fixing schemes.

    First, try to fix simple formatting errors by giving the model just
    the output and the parsing error. After a number of attempts, try
    giving the model the output, the parsing error, and the original
    input. Again check/retry this output to solve for formatting errors.
    If we still haven't succeeded after several attempts, the model may
    be getting thrown off by a bad initial output; start from scratch
    and try again.

    The number of tries for each layer of this scheme is roughly equal
    to the cube root of `self.max_prompts`, so the total calls to the
    LLM will be roughly as expected.

    Arguments:
        block: The block whose `original.text` is sent to the model.

    Returns:
        The parsed model output.

    Raises:
        OutputParserException: If no attempt produces parseable output. The
            last parsing failure, if any, is attached as the cause.
    """
    self._parser.set_reference(block.original)

    # A non-positive budget is treated as zero so the roots below stay real.
    budget = max(0, self.max_prompts)

    # Retries with just the output and the error. Clamped to at least 1:
    # a budget of 0 used to make n1 == 0 and the division below raised
    # ZeroDivisionError. Values for budgets >= 1 are unchanged.
    n1 = max(1, round(budget ** (1 / 3)))

    # Retries with the input, output, and error
    n2 = max(1, round((budget // n1) ** (1 / 2)))

    # Retries with just the input (0 when the budget is 0)
    n3 = math.ceil(budget / (n1 * n2))

    fix_format = OutputFixingParser.from_llm(
        llm=self._llm,
        parser=self._parser,
        max_retries=n1,
    )
    retry = RetryWithErrorOutputParser.from_llm(
        llm=self._llm,
        parser=fix_format,
        max_retries=n2,
    )

    completion_chain = self._prompt | self._llm
    chain = RunnableParallel(
        completion=completion_chain, prompt_value=self._prompt
    ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))

    last_error = None
    for _ in range(n3):
        try:
            return chain.invoke({"SOURCE_CODE": block.original.text})
        except OutputParserException as e:
            # Remember the failure so the final error can point at its cause.
            last_error = e

    raise OutputParserException(
        f"Failed to parse after {n1*n2*n3} retries"
    ) from last_error
442
-
443
def _save_to_file(self, block: CodeBlock, out_path: Path) -> None:
    """Write the parsed, combined text of `block` to `out_path`.

    Arguments:
        block: The `CodeBlock` to save to a file.
        out_path: Destination file; missing parent directories are created.
    """
    # TODO: can't use output fixer and this system for combining output
    parsed_text = self._parser.parse_combined_output(block.complete_text)

    destination_dir = out_path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    out_path.write_text(parsed_text, encoding="utf-8")
453
-
454
def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
    """Record the model name and any extra model arguments.

    Nothing is loaded here; the values are only stored on the instance. The
    affected objects will not be updated until translate() is called.

    Arguments:
        model_name: The name of the model to use. Valid models are found in
            `janus.llm.models_info.MODEL_CONSTRUCTORS`.
        custom_arguments: Additional arguments to pass to the model constructor.
    """
    self._model_name, self._custom_model_arguments = model_name, custom_arguments
466
-
467
def set_parser_type(self, parser_type: str) -> None:
    """Validate and store the parser type.

    The affected objects will not be updated until translate() is called.

    Arguments:
        parser_type: The type of parser to use for parsing the LLM output.
            Must be a member of `PARSER_TYPES`.

    Raises:
        ValueError: If `parser_type` is not in `PARSER_TYPES`.
    """
    if parser_type in PARSER_TYPES:
        self._parser_type = parser_type
        return

    raise ValueError(
        f'Unsupported parser type "{parser_type}". Valid types: {PARSER_TYPES}'
    )
482
-
483
- def set_prompt(self, prompt_template: str | Path) -> None:
484
- """Validate and set the prompt template name.
485
-
486
- The affected objects will not be updated until translate() is called.
487
-
488
- Arguments:
489
- prompt_template: name of prompt template directory
490
- (see janus/prompts/templates) or path to a directory.
491
- """
492
- self._prompt_template_name = prompt_template
493
-
494
def set_target_language(
    self, target_language: str, target_version: str | None
) -> None:
    """Validate and store the target language, its glob and its version.

    The language name is lower-cased before it is looked up in `LANGUAGES`.
    The affected objects will not be updated until translate() is called.

    Arguments:
        target_language: The target programming language.
        target_version: The target version of the target programming language.

    Raises:
        ValueError: If the lower-cased language is not a key of `LANGUAGES`.
    """
    normalized = target_language.lower()
    if normalized not in LANGUAGES:
        raise ValueError(
            f"Invalid target language: {normalized}. "
            "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
        )

    suffix = LANGUAGES[normalized]["suffix"]
    self._target_glob = f"**/*.{suffix}"
    self._target_language = normalized
    self._target_version = target_version
514
-
515
- def set_db_path(self, db_path: str) -> None:
516
- self._db_path = db_path
517
-
518
- def set_db_config(self, db_config: dict[str, Any] | None) -> None:
519
- self._db_config = db_config
520
-
521
@run_if_changed("_model_name", "_custom_model_arguments")
def _load_model(self) -> None:
    """Load the model according to this instance's attributes.

    If the relevant fields have not been changed since the last time this method was
    called, nothing happens.

    Sets `self._llm` and `self.model_cost`, and, unless the token limit was
    overridden by the user, `self._max_tokens`.
    """
    # NOTE(review): only the model name is passed to load_model, so
    # self._custom_model_arguments is not applied here - confirm whether
    # load_model should receive it.
    self._llm, token_limit, self.model_cost = load_model(self._model_name)

    # Set the max_tokens to 40% of the model's limit (limit // 2.5) to allow
    # for enough tokens at output.
    # Only modify max_tokens if it is not specified by user
    if not self.override_token_limit:
        # Floor division by a float yields a float (e.g. 4096 // 2.5 == 1638.0);
        # convert so _max_tokens stays an int as annotated.
        self._max_tokens = int(token_limit // 2.5)
540
-
541
@run_if_changed("_parser_type", "_target_language")
def _load_parser(self) -> None:
    """Build the output parser matching the parser type and target language.

    If the relevant fields have not been changed since the last time this method was
    called, nothing happens.

    Raises:
        ValueError: If the target language is "text" but the parser type is
            not, if a JSON-producing parser type ("eval", "multidoc",
            "madlibs") is paired with a non-"json" target language, or if the
            parser type is not recognized.
    """
    kind = self._parser_type
    language = self._target_language

    if language == "text" and kind != "text":
        raise ValueError(
            f"Target language ({language}) suggests target "
            f"parser should be 'text', but is '{kind}'"
        )
    if kind in {"eval", "multidoc", "madlibs"} and language != "json":
        raise ValueError(
            f"Parser type ({kind}) suggests target language"
            f" should be 'json', but is '{language}'"
        )

    # One zero-argument factory per supported parser type.
    factories = {
        "code": lambda: CodeParser(language=language),
        "eval": EvaluationParser,
        "multidoc": MultiDocumentationParser,
        "madlibs": MadlibsDocumentationParser,
        "text": GenericParser,
        "requirements": RequirementsParser,
    }
    make_parser = factories.get(kind)
    if make_parser is None:
        raise ValueError(f"Unsupported parser type: {kind}. Can be: {PARSER_TYPES}")
    self._parser = make_parser()
578
-
579
@run_if_changed(
    "_prompt_template_name",
    "_source_language",
    "_target_language",
    "_target_version",
    "_model_name",
)
def _load_prompt(self) -> None:
    """Build the prompt from the template, languages, version and model.

    If the relevant fields have not been changed since the last time this
    method was called, nothing happens.

    Raises:
        ValueError: If the template is in `SAME_OUTPUT` but the source and
            target languages differ, or if the template is in `TEXT_OUTPUT`
            but the target language is not "text".
    """
    template = self._prompt_template_name
    source = self._source_language
    target = self._target_language

    if template in SAME_OUTPUT and target != source:
        raise ValueError(
            f"Prompt template ({template}) suggests "
            f"source and target languages should match, but do not "
            f"({source} != {target})"
        )
    if template in TEXT_OUTPUT and target != "text":
        raise ValueError(
            f"Prompt template ({template}) suggests target "
            f"language should be 'text', but is '{target}'"
        )

    # The prompt engine class is looked up by model name.
    engine_cls = MODEL_PROMPT_ENGINES[self._model_name]
    engine = engine_cls(
        source_language=source,
        target_language=target,
        target_version=self._target_version,
        prompt_template=template,
    )
    self._prompt = engine.prompt
612
-
613
@run_if_changed("_db_path")
def _load_vectorizer(self) -> None:
    """Create the vectorizer for the configured database path.

    When no database path is set, `self._vectorizer` is set to `None`.
    """
    if self._db_path is not None:
        factory = ChromaDBVectorizer()
        self._vectorizer = factory.create_vectorizer(self._db_path, self._db_config)
    else:
        self._vectorizer = None
622
-
623
@run_if_changed(
    "_source_language",
    "_max_tokens",
    "_llm",
    "_protected_node_types",
    "_prune_node_types",
)
def _load_splitter(self) -> None:
    """Build the splitter, honoring a custom splitter when one was requested.

    With no custom splitter the parent class's `_load_splitter` is used.
    Otherwise the class registered under that name in `CUSTOM_SPLITTERS` is
    instantiated with this instance's settings.
    """
    splitter_name = self._custom_splitter
    if splitter_name is None:
        super()._load_splitter()
        return

    options = {
        "max_tokens": self._max_tokens,
        "model": self._llm,
        "protected_node_types": self._protected_node_types,
        "prune_node_types": self._prune_node_types,
    }
    # TODO: This should be configurable
    if splitter_name == "tag":
        options["tag"] = "<ITMOD_ALC_SPLIT>"

    splitter_cls = CUSTOM_SPLITTERS[splitter_name]
    self._splitter = splitter_cls(language=self._source_language, **options)
646
-
647
@run_if_changed("_target_language", "_parser_type")
def _load_combiner(self) -> None:
    """Pick the combiner from the parser type and target language.

    The "requirements" parser type takes precedence and uses `ChunkCombiner`;
    otherwise a "json" target language uses `JsonCombiner`, and anything else
    uses the base `Combiner`.
    """
    if self._parser_type == "requirements":
        combiner_cls = ChunkCombiner
    elif self._target_language == "json":
        combiner_cls = JsonCombiner
    else:
        combiner_cls = Combiner
    self._combiner = combiner_cls()
655
-
656
-
657
class Documenter(Translator):
    """A `Translator` preconfigured to produce text with the "document" prompt."""

    def __init__(
        self, source_language: str = "fortran", drop_comments: bool = True, **kwargs
    ):
        """Initialize the Documenter.

        Arguments:
            source_language: The source programming language.
            drop_comments: When true, the source language's comment node type
                is registered through `set_prune_node_types`.
            **kwargs: Forwarded to `Translator.__init__`. The target language,
                target version, prompt template and parser type are always
                overridden with the documentation settings.
        """
        documentation_settings = {
            "source_language": source_language,
            "target_language": "text",
            "target_version": None,
            "prompt_template": "document",
            "parser_type": "text",
        }
        # The fixed settings win over anything the caller supplied.
        super().__init__(**{**kwargs, **documentation_settings})

        if not drop_comments:
            return

        # Languages without an explicit entry fall back to "comment".
        language_info = LANGUAGES[source_language]
        node_type = language_info.get("comment_node_type", "comment")
        self.set_prune_node_types([node_type])
675
-
676
-
677
class MultiDocumenter(Documenter):
    """A `Documenter` using the "multidocument" prompt and "multidoc" parser."""

    def __init__(self, **kwargs):
        """Initialize as a `Documenter`, then switch to JSON multi-documentation.

        Arguments:
            **kwargs: Forwarded unchanged to `Documenter.__init__`.
        """
        super().__init__(**kwargs)
        self.set_prompt(prompt_template="multidocument")
        self.set_parser_type(parser_type="multidoc")
        self.set_target_language(target_language="json", target_version=None)
683
-
684
-
685
- class MadLibsDocumenter(Documenter):
686
def __init__(
    self,
    comments_per_request: int | None = None,
    **kwargs,
) -> None:
    """Initialize a documenter using the "document_madlibs" prompt.

    Arguments:
        comments_per_request: Largest number of comment placeholders sent in
            one request. `None` means a block is never split into groups.
        **kwargs: Forwarded to `Documenter.__init__`; `drop_comments` is
            always forced to `False`.
    """
    kwargs["drop_comments"] = False
    super().__init__(**kwargs)

    self.set_prompt(prompt_template="document_madlibs")
    self.set_parser_type(parser_type="madlibs")
    self.set_target_language(target_language="json", target_version=None)
    self.comments_per_request = comments_per_request
697
-
698
def _add_translation(self, block: TranslatedCodeBlock):
    """Translate `block`, batching its comment placeholders when required.

    Blocks that are already translated, have no text, contain no more
    placeholders than `comments_per_request`, or are processed with no
    per-request limit go through the parent implementation (or return early).
    A block without any placeholder is marked translated with no text.
    Otherwise the placeholders are processed in consecutive groups: for each
    group the placeholders outside it are stripped from a copy of the text,
    the copy is translated, and the parsed JSON results are merged.

    Arguments:
        block: The `TranslatedCodeBlock` to fill in; updated in place.
    """
    if block.translated:
        return

    if block.original.text is None:
        block.translated = True
        return

    batch_size = self.comments_per_request
    if batch_size is None:
        return super()._add_translation(block)

    full_text = block.original.text
    placeholder = r"<(?:INLINE|BLOCK)_COMMENT \w{8}>"
    matches = list(re.finditer(placeholder, full_text))

    if not matches:
        log.info(f"[{block.name}] Skipping commentless block")
        block.translated = True
        block.text = None
        block.complete = True
        return

    if len(matches) <= batch_size:
        return super()._add_translation(block)

    group_starts = list(range(0, len(matches), batch_size))
    log.debug(
        f"[{block.name}] Block contains more than {batch_size}"
        f" comments, splitting {len(matches)} comments into"
        f" {len(group_starts)} groups"
    )

    block.processing_time = 0
    block.cost = 0
    block.retries = 0
    merged = {}
    for first in group_starts:
        # Section holding the placeholders of interest, plus the text on
        # either side of it.
        group = matches[first : first + batch_size]
        lo = group[0].start()
        hi = group[-1].end()

        # Strip all comment placeholders outside of the section of interest
        before = re.sub(placeholder, "", full_text[:lo])
        after = re.sub(placeholder, "", full_text[hi:])

        # Build a new TranslatedCodeBlock using the reduced working text
        scratch = deepcopy(block.original)
        scratch.text = before + full_text[lo:hi] + after
        scratch_block = TranslatedCodeBlock(scratch, self._target_language)

        # Run the LLM on the working text
        super()._add_translation(scratch_block)

        # Accumulate metadata across all runs
        block.retries += scratch_block.retries
        block.cost += scratch_block.cost
        block.processing_time += scratch_block.processing_time

        # Merge this section's parsed output into the combined object
        section_output = self._parser.parse(scratch_block.text)
        merged.update(json.loads(section_output))

    # Point the parser back at the full original before the final parse.
    self._parser.set_reference(block.original)
    block.text = self._parser.parse(json.dumps(merged))
    block.tokens = self._llm.get_num_tokens(block.text)
    block.translated = True
773
-
774
def _get_obj(
    self, block: TranslatedCodeBlock
) -> dict[str, int | float | dict[str, str]]:
    """Build the JSON-serializable summary of a translated block.

    Arguments:
        block: The translated block to summarize.

    Returns:
        A dict with the block's total retries, total cost, processing time,
        and the comments decoded from its parsed, combined text.
    """
    combined_text = self._parser.parse_combined_output(block.complete_text)
    return {
        "retries": block.total_retries,
        "cost": block.total_cost,
        "processing_time": block.processing_time,
        "comments": json.loads(combined_text),
    }
785
-
786
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
    """Write the block's summary object to `out_path` as indented JSON.

    Arguments:
        block: The `TranslatedCodeBlock` to save to a file.
        out_path: Destination file; missing parent directories are created.
    """
    summary = self._get_obj(block)
    serialized = json.dumps(summary, indent=2)

    destination_dir = out_path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    out_path.write_text(serialized, encoding="utf-8")
795
-
796
-
797
class DiagramGenerator(Documenter):
    """DiagramGenerator

    A class that translates code from one programming language to a set of diagrams.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo-0125",
        model_arguments: dict[str, Any] | None = None,
        source_language: str = "fortran",
        max_prompts: int = 10,
        db_path: str | None = None,
        db_config: dict[str, Any] | None = None,
        diagram_type="Activity",
        add_documentation=False,
        custom_splitter: str | None = None,
    ) -> None:
        """Initialize the DiagramGenerator class

        Arguments:
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
                `None` (the default) is treated as an empty dictionary.
            source_language: The source programming language.
            max_prompts: The maximum number of prompts to try before giving up.
            db_path: path to chroma database
            db_config: database configuration
            diagram_type: type of PLANTUML diagram to generate
            add_documentation: If true, documentation is generated for each
                block first and supplied to the diagram prompt, and the
                "diagram_with_documentation" prompt template is used instead
                of "diagram".
            custom_splitter: Forwarded unchanged to the parent constructor.
        """
        # A fresh dict per instance: a shared `{}` default would be stored on
        # every instance and any mutation would leak between them.
        if model_arguments is None:
            model_arguments = {}
        super().__init__(
            model=model,
            model_arguments=model_arguments,
            source_language=source_language,
            max_prompts=max_prompts,
            db_path=db_path,
            db_config=db_config,
            custom_splitter=custom_splitter,
        )
        self._diagram_type = diagram_type
        self._add_documentation = add_documentation
        self._documenter = None
        self._model = model
        self._model_arguments = model_arguments
        self._max_prompts = max_prompts
        if add_documentation:
            self._diagram_prompt_template_name = "diagram_with_documentation"
        else:
            self._diagram_prompt_template_name = "diagram"
        self._load_diagram_prompt_engine()

    def _add_translation(self, block: TranslatedCodeBlock) -> None:
        """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
        `block.original`, setting the relevant fields in the translated block. The
        `TranslatedCodeBlock` is updated in-place, nothing is returned. Note that this
        translates *only* the code for this block, not its children.

        Arguments:
            block: An empty `TranslatedCodeBlock`

        Raises:
            ValueError: If documentation was requested but could not be
                produced, or if no LLM is configured.
        """
        if block.translated:
            return

        if block.original.text is None:
            block.translated = True
            return

        documentation = None
        if self._add_documentation:
            # Run the parent's documentation pass on a copy so that `block`
            # itself only ever receives the diagram output.
            documentation_block = deepcopy(block)
            super()._add_translation(documentation_block)
            if not documentation_block.translated:
                message = "Error: unable to produce documentation for code block"
                log.error(message)
                raise ValueError(message)
            documentation = json.loads(documentation_block.text)["docstring"]

        if self._llm is None:
            message = (
                "Model not configured correctly, cannot translate. Try setting "
                "the model"
            )
            log.error(message)
            raise ValueError(message)

        log.debug(f"[{block.name}] Translating...")
        log.debug(f"[{block.name}] Input text:\n{block.original.text}")

        self._parser.set_reference(block.original)

        query_and_parse = self.diagram_prompt | self._llm | self._parser

        # Both prompt templates take the source and diagram type; only the
        # documentation-aware template additionally takes DOCUMENTATION.
        prompt_inputs = {
            "SOURCE_CODE": block.original.text,
            "DIAGRAM_TYPE": self._diagram_type,
        }
        if self._add_documentation:
            prompt_inputs["DOCUMENTATION"] = documentation
        block.text = query_and_parse.invoke(prompt_inputs)
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(f"[{block.name}] Output code:\n{block.text}")

    @run_if_changed(
        "_diagram_prompt_template_name",
        "_source_language",
    )
    def _load_diagram_prompt_engine(self) -> None:
        """Load the prompt engine according to this instance's attributes.

        If the relevant fields have not been changed since the last time this method was
        called, nothing happens.

        Raises:
            ValueError: If the diagram prompt template is listed in
                `SAME_OUTPUT` but source and target languages differ, or is
                listed in `TEXT_OUTPUT` but the target language is not 'text'.
        """
        # Error messages report the diagram template, since that is the
        # template actually being validated here.
        template_name = self._diagram_prompt_template_name
        if template_name in SAME_OUTPUT:
            if self._target_language != self._source_language:
                raise ValueError(
                    f"Prompt template ({template_name}) suggests "
                    f"source and target languages should match, but do not "
                    f"({self._source_language} != {self._target_language})"
                )
        if template_name in TEXT_OUTPUT and self._target_language != "text":
            raise ValueError(
                f"Prompt template ({template_name}) suggests target "
                f"language should be 'text', but is '{self._target_language}'"
            )

        self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
            source_language=self._source_language,
            target_language=self._target_language,
            target_version=self._target_version,
            prompt_template=self._diagram_prompt_template_name,
        )
        self.diagram_prompt = self._diagram_prompt_engine.prompt
942
-
943
-
944
class RequirementsDocumenter(Documenter):
    """RequirementsDocumenter

    A class that translates code from one programming language to its requirements.
    """

    def __init__(self, **kwargs):
        """Initialize the parent with `kwargs`, then select the "requirements"
        prompt, the "json" target language and the "requirements" parser.
        """
        super().__init__(**kwargs)
        self.set_prompt("requirements")
        self.set_target_language("json", None)
        self.set_parser_type("requirements")

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Save a file to disk.

        Arguments:
            block: The `CodeBlock` to save to a file.
            out_path: Destination file; missing parent directories are created.
        """
        records = []
        # One record per child chunk: its source text, the requirements parsed
        # from its output, and the generation metadata.
        for chunk in block.children:
            source_text = chunk.original.text
            parsed_requirements = self._parser.parse_combined_output(
                chunk.complete_text
            )
            run_metadata = {
                "retries": chunk.total_retries,
                "cost": chunk.total_cost,
                "processing_time": chunk.processing_time,
            }
            records.append(
                {
                    "metadata": run_metadata,
                    "code": source_text,
                    "requirements": parsed_requirements,
                }
            )
        # Everything lives under a single top-level 'output' key.
        document = {"output": records}
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(document, indent=2), encoding="utf-8")