janus-llm 2.0.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. janus/__init__.py +2 -2
  2. janus/__main__.py +1 -1
  3. janus/_tests/test_cli.py +1 -2
  4. janus/cli.py +43 -51
  5. janus/converter/__init__.py +6 -0
  6. janus/converter/_tests/__init__.py +0 -0
  7. janus/{_tests → converter/_tests}/test_translate.py +11 -22
  8. janus/converter/converter.py +614 -0
  9. janus/converter/diagram.py +124 -0
  10. janus/converter/document.py +131 -0
  11. janus/converter/evaluate.py +15 -0
  12. janus/converter/requirements.py +50 -0
  13. janus/converter/translate.py +108 -0
  14. janus/embedding/_tests/test_collections.py +2 -2
  15. janus/language/_tests/test_splitter.py +1 -1
  16. janus/language/alc/__init__.py +1 -0
  17. janus/language/alc/_tests/__init__.py +0 -0
  18. janus/language/alc/_tests/test_alc.py +28 -0
  19. janus/language/alc/alc.py +87 -0
  20. janus/language/block.py +4 -2
  21. janus/language/combine.py +0 -1
  22. janus/language/mumps/mumps.py +2 -3
  23. janus/language/naive/__init__.py +1 -1
  24. janus/language/naive/basic_splitter.py +4 -4
  25. janus/language/naive/chunk_splitter.py +4 -4
  26. janus/language/naive/registry.py +1 -1
  27. janus/language/naive/simple_ast.py +23 -12
  28. janus/language/naive/tag_splitter.py +4 -4
  29. janus/language/splitter.py +10 -4
  30. janus/language/treesitter/treesitter.py +26 -8
  31. janus/llm/model_callbacks.py +34 -37
  32. janus/llm/models_info.py +16 -3
  33. janus/metrics/_tests/test_llm.py +2 -3
  34. janus/metrics/_tests/test_rouge_score.py +1 -1
  35. janus/metrics/_tests/test_similarity_score.py +1 -1
  36. janus/metrics/complexity_metrics.py +3 -4
  37. janus/metrics/metric.py +3 -4
  38. janus/metrics/reading.py +27 -5
  39. janus/prompts/prompt.py +67 -7
  40. janus/utils/enums.py +6 -5
  41. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/METADATA +1 -1
  42. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/RECORD +45 -35
  43. janus/converter.py +0 -158
  44. janus/translate.py +0 -981
  45. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/LICENSE +0 -0
  46. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/WHEEL +0 -0
  47. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/entry_points.txt +0 -0
janus/translate.py DELETED
@@ -1,981 +0,0 @@
1
- import json
2
- import math
3
- import re
4
- import time
5
- from copy import deepcopy
6
- from pathlib import Path
7
- from typing import Any
8
-
9
- from langchain.output_parsers import RetryWithErrorOutputParser
10
- from langchain.output_parsers.fix import OutputFixingParser
11
- from langchain_core.exceptions import OutputParserException
12
- from langchain_core.language_models import BaseLanguageModel
13
- from langchain_core.output_parsers import BaseOutputParser
14
- from langchain_core.prompts import ChatPromptTemplate
15
- from langchain_core.runnables import RunnableLambda, RunnableParallel
16
- from openai import BadRequestError, RateLimitError
17
- from text_generation.errors import ValidationError
18
-
19
- from janus.language.naive.registry import CUSTOM_SPLITTERS
20
-
21
- from .converter import Converter, run_if_changed
22
- from .embedding.vectorize import ChromaDBVectorizer
23
- from .language.block import CodeBlock, TranslatedCodeBlock
24
- from .language.combine import ChunkCombiner, Combiner, JsonCombiner
25
- from .language.splitter import EmptyTreeError, FileSizeError, TokenLimitError
26
- from .llm import load_model
27
- from .llm.model_callbacks import get_model_callback
28
- from .llm.models_info import MODEL_PROMPT_ENGINES
29
- from .parsers.code_parser import CodeParser, GenericParser
30
- from .parsers.doc_parser import MadlibsDocumentationParser, MultiDocumentationParser
31
- from .parsers.eval_parser import EvaluationParser
32
- from .parsers.reqs_parser import RequirementsParser
33
- from .prompts.prompt import SAME_OUTPUT, TEXT_OUTPUT
34
- from .utils.enums import LANGUAGES
35
- from .utils.logger import create_logger
36
-
37
- log = create_logger(__name__)
38
-
39
-
40
- PARSER_TYPES: set[str] = {"code", "text", "eval", "madlibs", "multidoc", "requirements"}
41
-
42
-
43
- class Translator(Converter):
44
- """A class that translates code from one programming language to another."""
45
-
46
- def __init__(
47
- self,
48
- model: str = "gpt-3.5-turbo-0125",
49
- model_arguments: dict[str, Any] = {},
50
- source_language: str = "fortran",
51
- target_language: str = "python",
52
- target_version: str | None = "3.10",
53
- max_prompts: int = 10,
54
- max_tokens: int | None = None,
55
- prompt_template: str | Path = "simple",
56
- parser_type: str = "code",
57
- db_path: str | None = None,
58
- db_config: dict[str, Any] | None = None,
59
- custom_splitter: str | None = None,
60
- ) -> None:
61
- """Initialize a Translator instance.
62
-
63
- Arguments:
64
- model: The LLM to use for translation. If an OpenAI model, the
65
- `OPENAI_API_KEY` environment variable must be set and the
66
- `OPENAI_ORG_ID` environment variable should be set if needed.
67
- model_arguments: Additional arguments to pass to the LLM constructor.
68
- source_language: The source programming language.
69
- target_language: The target programming language.
70
- target_version: The target version of the target programming language.
71
- max_prompts: The maximum number of prompts to try before giving up.
72
- max_tokens: The maximum number of tokens the model will take in.
73
- If unspecified, model's default max will be used.
74
- prompt_template: name of prompt template directory
75
- (see janus/prompts/templates) or path to a directory.
76
- parser_type: The type of parser to use for parsing the LLM output. Valid
77
- values are "code" (default), "text", and "eval".
78
- """
79
- self._custom_splitter = custom_splitter
80
- super().__init__(source_language=source_language)
81
-
82
- self._parser_type: str | None
83
- self._model_name: str | None
84
- self._custom_model_arguments: dict[str, Any] | None
85
- self._target_language: str | None
86
- self._target_version: str | None
87
- self._target_glob: str | None
88
- self._prompt_template_name: str | None
89
- self._db_path: str | None
90
- self._db_config: dict[str, Any] | None
91
-
92
- self._llm: BaseLanguageModel | None
93
- self._parser: BaseOutputParser | None
94
- self._combiner: Combiner | None
95
- self._prompt: ChatPromptTemplate | None
96
-
97
- self.max_prompts = max_prompts
98
- self.override_token_limit = False if max_tokens is None else True
99
- self._max_tokens = max_tokens
100
-
101
- self.set_model(model_name=model, **model_arguments)
102
- self.set_parser_type(parser_type=parser_type)
103
- self.set_prompt(prompt_template=prompt_template)
104
- self.set_target_language(
105
- target_language=target_language,
106
- target_version=target_version,
107
- )
108
- self.set_db_path(db_path=db_path)
109
- self.set_db_config(db_config=db_config)
110
-
111
- self._load_parameters()
112
-
113
- def _load_parameters(self) -> None:
114
- self._load_model()
115
- self._load_prompt()
116
- self._load_parser()
117
- self._load_combiner()
118
- self._load_vectorizer()
119
- super()._load_parameters() # will call self._changed_attrs.clear()
120
-
121
- def translate(
122
- self,
123
- input_directory: str | Path,
124
- output_directory: str | Path | None = None,
125
- overwrite: bool = False,
126
- collection_name: str | None = None,
127
- ) -> None:
128
- """Translate code in the input directory from the source language to the target
129
- language, and write the resulting files to the output directory.
130
-
131
- Arguments:
132
- input_directory: The directory containing the code to translate.
133
- output_directory: The directory to write the translated code to.
134
- overwrite: Whether to overwrite existing files (vs skip them)
135
- collection_name: Collection to add to
136
- """
137
- # Convert paths to pathlib Paths if needed
138
- if isinstance(input_directory, str):
139
- input_directory = Path(input_directory)
140
- if isinstance(output_directory, str):
141
- output_directory = Path(output_directory)
142
-
143
- # Make sure the output directory exists
144
- if output_directory is not None and not output_directory.exists():
145
- output_directory.mkdir(parents=True)
146
-
147
- source_suffix = LANGUAGES[self._source_language]["suffix"]
148
- target_suffix = LANGUAGES[self._target_language]["suffix"]
149
-
150
- input_paths = [p for p in input_directory.rglob(self._source_glob)]
151
-
152
- log.info(f"Input directory: {input_directory.absolute()}")
153
- log.info(
154
- f"{self._source_language.capitalize()} '*.{source_suffix}' files: "
155
- f"{len(input_paths)}"
156
- )
157
- log.info(
158
- "Other files (skipped): "
159
- f"{len(list(input_directory.iterdir())) - len(input_paths)}\n"
160
- )
161
- if output_directory is not None:
162
- output_paths = [
163
- output_directory
164
- / p.relative_to(input_directory).with_suffix(f".{target_suffix}")
165
- for p in input_paths
166
- ]
167
- in_out_pairs = list(zip(input_paths, output_paths))
168
- if not overwrite:
169
- n_files = len(in_out_pairs)
170
- in_out_pairs = [
171
- (inp, outp) for inp, outp in in_out_pairs if not outp.exists()
172
- ]
173
- log.info(
174
- f"Skipping {n_files - len(in_out_pairs)} existing "
175
- f"'*.{source_suffix}' files"
176
- )
177
- else:
178
- in_out_pairs = [(f, None) for f in input_paths]
179
- log.info(f"Translating {len(in_out_pairs)} '*.{source_suffix}' files")
180
-
181
- # Now, loop through every code block in every file and translate it with an LLM
182
- total_cost = 0.0
183
- for in_path, out_path in in_out_pairs:
184
- # Translate the file, skip it if there's a rate limit error
185
- try:
186
- out_block = self.translate_file(in_path)
187
- total_cost += out_block.total_cost
188
- except RateLimitError:
189
- continue
190
- except OutputParserException as e:
191
- log.error(f"Skipping {in_path.name}, failed to parse output: {e}.")
192
- continue
193
- except BadRequestError as e:
194
- if str(e).startswith("Detected an error in the prompt"):
195
- log.warning("Malformed input, skipping")
196
- continue
197
- raise e
198
- except ValidationError as e:
199
- # Only allow ValidationError to pass if token limit is manually set
200
- if self.override_token_limit:
201
- log.warning(
202
- "Current file and manually set token "
203
- "limit is too large for this model, skipping"
204
- )
205
- continue
206
- raise e
207
- except TokenLimitError:
208
- log.warning("Ran into irreducible node too large for context, skipping")
209
- continue
210
- except EmptyTreeError:
211
- log.warning(
212
- f'Input file "{in_path.name}" has no nodes of interest, skipping'
213
- )
214
- continue
215
- except FileSizeError:
216
- log.warning("Current tile is too large for basic splitter, skipping")
217
- continue
218
-
219
- # Don't attempt to write files for which translation failed
220
- if not out_block.translated:
221
- continue
222
-
223
- # # maybe want target embeddings?
224
- # if self.outputting_requirements():
225
- # filename = str(relative)
226
- # embedding_type = EmbeddingType.REQUIREMENT
227
- # elif self.outputting_summary():
228
- # filename = str(relative)
229
- # embedding_type = EmbeddingType.SUMMARY
230
- # elif self.outputting_pseudocode():
231
- # filename = out_path.name
232
- # embedding_type = EmbeddingType.PSEUDO
233
- # else:
234
- # filename = out_path.name
235
- # embedding_type = EmbeddingType.TARGET
236
- #
237
- # self._embed_nodes_recursively(out_block, embedding_type, filename)
238
-
239
- if collection_name is not None:
240
- self._vectorizer.add_nodes_recursively(
241
- out_block,
242
- collection_name,
243
- in_path.name,
244
- )
245
- # out_text = self.parser.parse_combined_output(out_block.complete_text)
246
- # # Using same id naming convention from vectorize.py
247
- # ids = [str(uuid.uuid3(uuid.NAMESPACE_DNS, out_text))]
248
- # output_collection.upsert(ids=ids, documents=[out_text])
249
-
250
- # Make sure the tree's code has been consolidated at the top level
251
- # before writing to file
252
- self._combiner.combine(out_block)
253
- if out_path is not None and (overwrite or not out_path.exists()):
254
- self._save_to_file(out_block, out_path)
255
-
256
- log.info(f"Total cost: ${total_cost:,.2f}")
257
-
258
- def _split_file(self, file: Path) -> CodeBlock:
259
- filename = file.name
260
- log.info(f"[{filename}] Splitting file")
261
- root = self._splitter.split(file)
262
- log.info(
263
- f"[{filename}] File split into {root.n_descendents:,} blocks, "
264
- f"tree of height {root.height}"
265
- )
266
- log.info(f"[{filename}] Input CodeBlock Structure:\n{root.tree_str()}")
267
- return root
268
-
269
- def translate_file(self, file: Path) -> TranslatedCodeBlock:
270
- """Translate a single file.
271
-
272
- Arguments:
273
- file: Input path to file
274
-
275
- Returns:
276
- A `TranslatedCodeBlock` object. This block does not have a path set, and its
277
- code is not guaranteed to be consolidated. To amend this, run
278
- `Combiner.combine_children` on the block.
279
- """
280
- self._load_parameters()
281
- filename = file.name
282
-
283
- input_block = self._split_file(file)
284
- t0 = time.time()
285
- output_block = self._iterative_translate(input_block)
286
- output_block.processing_time = time.time() - t0
287
- if output_block.translated:
288
- completeness = output_block.translation_completeness
289
- log.info(
290
- f"[{filename}] Translation complete\n"
291
- f" {completeness:.2%} of input successfully translated\n"
292
- f" Total cost: ${output_block.total_cost:,.2f}\n"
293
- f" Total retries: {output_block.total_retries:,d}\n"
294
- f" Output CodeBlock Structure:\n{input_block.tree_str()}\n"
295
- )
296
-
297
- else:
298
- log.error(
299
- f"[{filename}] Translation failed\n"
300
- f" Total cost: ${output_block.total_cost:,.2f}\n"
301
- f" Total retries: {output_block.total_retries:,d}\n"
302
- )
303
- return output_block
304
-
305
- def outputting_requirements(self) -> bool:
306
- """Is the output of the translator a requirements file?"""
307
- # expect we will revise system to output more than a single output
308
- # so this is placeholder logic
309
- return self._prompt_template_name == "requirements"
310
-
311
- def outputting_summary(self) -> bool:
312
- """Is the output of the translator a summary documentation?"""
313
- return self._prompt_template_name == "document"
314
-
315
- def outputting_pseudocode(self) -> bool:
316
- """Is the output of the translator pseudocode?"""
317
- # expect we will revise system to output more than a single output
318
- # so this is placeholder logic
319
- return self._prompt_template_name == "pseudocode"
320
-
321
- def _iterative_translate(self, root: CodeBlock) -> TranslatedCodeBlock:
322
- """Translate the passed CodeBlock representing a full file.
323
-
324
- Arguments:
325
- root: A root block representing the top-level block of a file
326
-
327
- Returns:
328
- A `TranslatedCodeBlock`
329
- """
330
- translated_root = TranslatedCodeBlock(root, self._target_language)
331
- last_prog, prog_delta = 0, 0.1
332
- stack = [translated_root]
333
- while stack:
334
- translated_block = stack.pop()
335
-
336
- self._add_translation(translated_block)
337
-
338
- # If translating this block was unsuccessful, don't bother with its
339
- # children (they wouldn't show up in the final text anyway)
340
- if not translated_block.translated:
341
- continue
342
-
343
- stack.extend(translated_block.children)
344
-
345
- progress = translated_root.translation_completeness
346
- if progress - last_prog > prog_delta:
347
- last_prog = int(progress / prog_delta) * prog_delta
348
- log.info(f"[{root.name}] progress: {progress:.2%}")
349
-
350
- return translated_root
351
-
352
- def _add_translation(self, block: TranslatedCodeBlock) -> None:
353
- """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
354
- `block.original`, setting the relevant fields in the translated block. The
355
- `TranslatedCodeBlock` is updated in-place, nothing is returned. Note that this
356
- translates *only* the code for this block, not its children.
357
-
358
- Arguments:
359
- block: An empty `TranslatedCodeBlock`
360
- """
361
- if block.translated:
362
- return
363
-
364
- if block.original.text is None:
365
- block.translated = True
366
- return
367
-
368
- if self._llm is None:
369
- message = (
370
- "Model not configured correctly, cannot translate. Try setting "
371
- "the model"
372
- )
373
- log.error(message)
374
- raise ValueError(message)
375
-
376
- log.debug(f"[{block.name}] Translating...")
377
- log.debug(f"[{block.name}] Input text:\n{block.original.text}")
378
-
379
- # Track the cost of translating this block
380
- # TODO: If non-OpenAI models with prices are added, this will need
381
- # to be updated.
382
- with get_model_callback() as cb:
383
- t0 = time.time()
384
- block.text = self._run_chain(block)
385
- block.processing_time = time.time() - t0
386
- block.cost = cb.total_cost
387
- block.retries = max(0, cb.successful_requests - 1)
388
-
389
- block.tokens = self._llm.get_num_tokens(block.text)
390
- block.translated = True
391
-
392
- log.debug(f"[{block.name}] Output code:\n{block.text}")
393
-
394
- def _run_chain(self, block: TranslatedCodeBlock) -> str:
395
- """Run the model with three nested error fixing schemes.
396
- First, try to fix simple formatting errors by giving the model just
397
- the output and the parsing error. After a number of attempts, try
398
- giving the model the output, the parsing error, and the original
399
- input. Again check/retry this output to solve for formatting errors.
400
- If we still haven't succeeded after several attempts, the model may
401
- be getting thrown off by a bad initial output; start from scratch
402
- and try again.
403
-
404
- The number of tries for each layer of this scheme is roughly equal
405
- to the cube root of self.max_prompts, so the total calls to the
405
- LLM will be roughly as expected (up to sqrt(self.max_prompts) over)
407
- """
408
- self._parser.set_reference(block.original)
409
-
410
- # Retries with just the output and the error
411
- n1 = round(self.max_prompts ** (1 / 3))
412
-
413
- # Retries with the input, output, and error
414
- n2 = round((self.max_prompts // n1) ** (1 / 2))
415
-
416
- # Retries with just the input
417
- n3 = math.ceil(self.max_prompts / (n1 * n2))
418
-
419
- fix_format = OutputFixingParser.from_llm(
420
- llm=self._llm,
421
- parser=self._parser,
422
- max_retries=n1,
423
- )
424
- retry = RetryWithErrorOutputParser.from_llm(
425
- llm=self._llm,
426
- parser=fix_format,
427
- max_retries=n2,
428
- )
429
-
430
- completion_chain = self._prompt | self._llm
431
- chain = RunnableParallel(
432
- completion=completion_chain, prompt_value=self._prompt
433
- ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
434
-
435
- for _ in range(n3):
436
- try:
437
- return chain.invoke({"SOURCE_CODE": block.original.text})
438
- except OutputParserException:
439
- pass
440
-
441
- raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
442
-
443
- def _save_to_file(self, block: CodeBlock, out_path: Path) -> None:
444
- """Save a file to disk.
445
-
446
- Arguments:
447
- block: The `CodeBlock` to save to a file.
448
- """
449
- # TODO: can't use output fixer and this system for combining output
450
- out_text = self._parser.parse_combined_output(block.complete_text)
451
- out_path.parent.mkdir(parents=True, exist_ok=True)
452
- out_path.write_text(out_text, encoding="utf-8")
453
-
454
- def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
455
- """Validate and set the model name.
456
-
457
- The affected objects will not be updated until translate() is called.
458
-
459
- Arguments:
460
- model_name: The name of the model to use. Valid models are found in
461
- `janus.llm.models_info.MODEL_CONSTRUCTORS`.
462
- custom_arguments: Additional arguments to pass to the model constructor.
463
- """
464
- self._model_name = model_name
465
- self._custom_model_arguments = custom_arguments
466
-
467
- def set_parser_type(self, parser_type: str) -> None:
468
- """Validate and set the parser type.
469
-
470
- The affected objects will not be updated until translate() is called.
471
-
472
- Arguments:
473
- parser_type: The type of parser to use for parsing the LLM output. Valid
474
- values are "code" (default), "text", and "eval".
475
- """
476
- if parser_type not in PARSER_TYPES:
477
- raise ValueError(
478
- f'Unsupported parser type "{parser_type}". Valid types: '
479
- f"{PARSER_TYPES}"
480
- )
481
- self._parser_type = parser_type
482
-
483
- def set_prompt(self, prompt_template: str | Path) -> None:
484
- """Validate and set the prompt template name.
485
-
486
- The affected objects will not be updated until translate() is called.
487
-
488
- Arguments:
489
- prompt_template: name of prompt template directory
490
- (see janus/prompts/templates) or path to a directory.
491
- """
492
- self._prompt_template_name = prompt_template
493
-
494
- def set_target_language(
495
- self, target_language: str, target_version: str | None
496
- ) -> None:
497
- """Validate and set the target language.
498
-
499
- The affected objects will not be updated until translate() is called.
500
-
501
- Arguments:
502
- target_language: The target programming language.
503
- target_version: The target version of the target programming language.
504
- """
505
- target_language = target_language.lower()
506
- if target_language not in LANGUAGES:
507
- raise ValueError(
508
- f"Invalid target language: {target_language}. "
509
- "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
510
- )
511
- self._target_glob = f"**/*.{LANGUAGES[target_language]['suffix']}"
512
- self._target_language = target_language
513
- self._target_version = target_version
514
-
515
- def set_db_path(self, db_path: str) -> None:
516
- self._db_path = db_path
517
-
518
- def set_db_config(self, db_config: dict[str, Any] | None) -> None:
519
- self._db_config = db_config
520
-
521
- @run_if_changed("_model_name", "_custom_model_arguments")
522
- def _load_model(self) -> None:
523
- """Load the model according to this instance's attributes.
524
-
525
- If the relevant fields have not been changed since the last time this method was
526
- called, nothing happens.
527
- """
528
-
529
- # Get default arguments, set custom ones
530
- # model_arguments = deepcopy(MODEL_DEFAULT_ARGUMENTS[self._model_name])
531
- # model_arguments.update(self._custom_model_arguments)
532
-
533
- # Load the model
534
- self._llm, token_limit, self.model_cost = load_model(self._model_name)
535
- # Set the max_tokens to less than half the model's limit to allow for enough
536
- # tokens at output
537
- # Only modify max_tokens if it is not specified by user
538
- if not self.override_token_limit:
539
- self._max_tokens = token_limit // 2.5
540
-
541
- @run_if_changed("_parser_type", "_target_language")
542
- def _load_parser(self) -> None:
543
- """Load the parser according to this instance's attributes.
544
-
545
- If the relevant fields have not been changed since the last time this method was
546
- called, nothing happens.
547
- """
548
- if "text" == self._target_language and self._parser_type != "text":
549
- raise ValueError(
550
- f"Target language ({self._target_language}) suggests target "
551
- f"parser should be 'text', but is '{self._parser_type}'"
552
- )
553
- if (
554
- self._parser_type in {"eval", "multidoc", "madlibs"}
555
- and "json" != self._target_language
556
- ):
557
- raise ValueError(
558
- f"Parser type ({self._parser_type}) suggests target language"
559
- f" should be 'json', but is '{self._target_language}'"
560
- )
561
- if "code" == self._parser_type:
562
- self._parser = CodeParser(language=self._target_language)
563
- elif "eval" == self._parser_type:
564
- self._parser = EvaluationParser()
565
- elif "multidoc" == self._parser_type:
566
- self._parser = MultiDocumentationParser()
567
- elif "madlibs" == self._parser_type:
568
- self._parser = MadlibsDocumentationParser()
569
- elif "text" == self._parser_type:
570
- self._parser = GenericParser()
571
- elif "requirements" == self._parser_type:
572
- self._parser = RequirementsParser()
573
- else:
574
- raise ValueError(
575
- f"Unsupported parser type: {self._parser_type}. Can be: "
576
- f"{PARSER_TYPES}"
577
- )
578
-
579
- @run_if_changed(
580
- "_prompt_template_name",
581
- "_source_language",
582
- "_target_language",
583
- "_target_version",
584
- "_model_name",
585
- )
586
- def _load_prompt(self) -> None:
587
- """Load the prompt according to this instance's attributes.
588
-
589
- If the relevant fields have not been changed since the last time this
590
- method was called, nothing happens.
591
- """
592
- if self._prompt_template_name in SAME_OUTPUT:
593
- if self._target_language != self._source_language:
594
- raise ValueError(
595
- f"Prompt template ({self._prompt_template_name}) suggests "
596
- f"source and target languages should match, but do not "
597
- f"({self._source_language} != {self._target_language})"
598
- )
599
- if self._prompt_template_name in TEXT_OUTPUT and self._target_language != "text":
600
- raise ValueError(
601
- f"Prompt template ({self._prompt_template_name}) suggests target "
602
- f"language should be 'text', but is '{self._target_language}'"
603
- )
604
-
605
- prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
606
- source_language=self._source_language,
607
- target_language=self._target_language,
608
- target_version=self._target_version,
609
- prompt_template=self._prompt_template_name,
610
- )
611
- self._prompt = prompt_engine.prompt
612
-
613
- @run_if_changed("_db_path")
614
- def _load_vectorizer(self) -> None:
615
- if self._db_path is None:
616
- self._vectorizer = None
617
- return
618
- vectorizer_factory = ChromaDBVectorizer()
619
- self._vectorizer = vectorizer_factory.create_vectorizer(
620
- self._db_path, self._db_config
621
- )
622
-
623
- @run_if_changed(
624
- "_source_language",
625
- "_max_tokens",
626
- "_llm",
627
- "_protected_node_types",
628
- "_prune_node_types",
629
- )
630
- def _load_splitter(self) -> None:
631
- if self._custom_splitter is None:
632
- super()._load_splitter()
633
- else:
634
- kwargs = dict(
635
- max_tokens=self._max_tokens,
636
- model=self._llm,
637
- protected_node_types=self._protected_node_types,
638
- prune_node_types=self._prune_node_types,
639
- )
640
- # TODO: This should be configurable
641
- if self._custom_splitter == "tag":
642
- kwargs["tag"] = "<ITMOD_ALC_SPLIT>"
643
- self._splitter = CUSTOM_SPLITTERS[self._custom_splitter](
644
- language=self._source_language, **kwargs
645
- )
646
-
647
- @run_if_changed("_target_language", "_parser_type")
648
- def _load_combiner(self) -> None:
649
- if self._parser_type == "requirements":
650
- self._combiner = ChunkCombiner()
651
- elif self._target_language == "json":
652
- self._combiner = JsonCombiner()
653
- else:
654
- self._combiner = Combiner()
655
-
656
-
657
- class Documenter(Translator):
658
- def __init__(
659
- self, source_language: str = "fortran", drop_comments: bool = True, **kwargs
660
- ):
661
- kwargs.update(
662
- source_language=source_language,
663
- target_language="text",
664
- target_version=None,
665
- prompt_template="document",
666
- parser_type="text",
667
- )
668
- super().__init__(**kwargs)
669
-
670
- if drop_comments:
671
- comment_node_type = LANGUAGES[source_language].get(
672
- "comment_node_type", "comment"
673
- )
674
- self.set_prune_node_types([comment_node_type])
675
-
676
-
677
- class MultiDocumenter(Documenter):
678
- def __init__(self, **kwargs):
679
- super().__init__(**kwargs)
680
- self.set_prompt("multidocument")
681
- self.set_parser_type("multidoc")
682
- self.set_target_language("json", None)
683
-
684
-
685
- class MadLibsDocumenter(Documenter):
686
- def __init__(
687
- self,
688
- comments_per_request: int | None = None,
689
- **kwargs,
690
- ) -> None:
691
- kwargs.update(drop_comments=False)
692
- super().__init__(**kwargs)
693
- self.set_prompt("document_madlibs")
694
- self.set_parser_type("madlibs")
695
- self.set_target_language("json", None)
696
- self.comments_per_request = comments_per_request
697
-
698
- def _add_translation(self, block: TranslatedCodeBlock):
699
- if block.translated:
700
- return
701
-
702
- if block.original.text is None:
703
- block.translated = True
704
- return
705
-
706
- if self.comments_per_request is None:
707
- return super()._add_translation(block)
708
-
709
- comment_pattern = r"<(?:INLINE|BLOCK)_COMMENT \w{8}>"
710
- comments = list(
711
- re.finditer(
712
- comment_pattern,
713
- block.original.text,
714
- )
715
- )
716
-
717
- if not comments:
718
- log.info(f"[{block.name}] Skipping commentless block")
719
- block.translated = True
720
- block.text = None
721
- block.complete = True
722
- return
723
-
724
- if len(comments) <= self.comments_per_request:
725
- return super()._add_translation(block)
726
-
727
- comment_group_indices = list(range(0, len(comments), self.comments_per_request))
728
- log.debug(
729
- f"[{block.name}] Block contains more than {self.comments_per_request}"
730
- f" comments, splitting {len(comments)} comments into"
731
- f" {len(comment_group_indices)} groups"
732
- )
733
-
734
- block.processing_time = 0
735
- block.cost = 0
736
- block.retries = 0
737
- obj = {}
738
- for i in range(0, len(comments), self.comments_per_request):
739
- # Split the text into the section containing comments of interest,
740
- # all the text prior to those comments, and all the text after them
741
- working_comments = comments[i : i + self.comments_per_request]
742
- start_idx = working_comments[0].start()
743
- end_idx = working_comments[-1].end()
744
- prefix = block.original.text[:start_idx]
745
- keeper = block.original.text[start_idx:end_idx]
746
- suffix = block.original.text[end_idx:]
747
-
748
- # Strip all comment placeholders outside of the section of interest
749
- prefix = re.sub(comment_pattern, "", prefix)
750
- suffix = re.sub(comment_pattern, "", suffix)
751
-
752
- # Build a new TranslatedBlock using the new working text
753
- working_copy = deepcopy(block.original)
754
- working_copy.text = prefix + keeper + suffix
755
- working_block = TranslatedCodeBlock(working_copy, self._target_language)
756
-
757
- # Run the LLM on the working text
758
- super()._add_translation(working_block)
759
-
760
- # Update metadata to include for all runs
761
- block.retries += working_block.retries
762
- block.cost += working_block.cost
763
- block.processing_time += working_block.processing_time
764
-
765
- # Update the output text to merge this section's output in
766
- out_text = self._parser.parse(working_block.text)
767
- obj.update(json.loads(out_text))
768
-
769
- self._parser.set_reference(block.original)
770
- block.text = self._parser.parse(json.dumps(obj))
771
- block.tokens = self._llm.get_num_tokens(block.text)
772
- block.translated = True
773
-
774
- def _get_obj(
775
- self, block: TranslatedCodeBlock
776
- ) -> dict[str, int | float | dict[str, str]]:
777
- out_text = self._parser.parse_combined_output(block.complete_text)
778
- obj = dict(
779
- retries=block.total_retries,
780
- cost=block.total_cost,
781
- processing_time=block.processing_time,
782
- comments=json.loads(out_text),
783
- )
784
- return obj
785
-
786
- def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
787
- """Save a file to disk.
788
-
789
- Arguments:
790
- block: The `CodeBlock` to save to a file.
791
- """
792
- obj = self._get_obj(block)
793
- out_path.parent.mkdir(parents=True, exist_ok=True)
794
- out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
795
-
796
-
797
class DiagramGenerator(Documenter):
    """DiagramGenerator

    A class that translates code from one programming language to a set of diagrams.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo-0125",
        model_arguments: dict[str, Any] | None = None,
        source_language: str = "fortran",
        max_prompts: int = 10,
        db_path: str | None = None,
        db_config: dict[str, Any] | None = None,
        diagram_type="Activity",
        add_documentation=False,
        custom_splitter: str | None = None,
    ) -> None:
        """Initialize the DiagramGenerator class

        Arguments:
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
                `None` (the default) is treated as an empty dictionary.
            source_language: The source programming language.
            max_prompts: The maximum number of prompts to try before giving up.
            db_path: path to chroma database
            db_config: database configuration
            diagram_type: type of PLANTUML diagram to generate
            add_documentation: if true, documentation is generated for each block
                first and supplied to the diagram prompt
            custom_splitter: forwarded unchanged to the parent constructor
        """
        # A fresh dict per call: a `{}` default would be shared by every instance
        # constructed without an explicit `model_arguments`.
        if model_arguments is None:
            model_arguments = {}

        super().__init__(
            model=model,
            model_arguments=model_arguments,
            source_language=source_language,
            max_prompts=max_prompts,
            db_path=db_path,
            db_config=db_config,
            custom_splitter=custom_splitter,
        )
        self._diagram_type = diagram_type
        self._add_documentation = add_documentation
        self._documenter = None
        self._model = model
        self._model_arguments = model_arguments
        self._max_prompts = max_prompts
        # The prompt template depends on whether documentation is injected.
        if add_documentation:
            self._diagram_prompt_template_name = "diagram_with_documentation"
        else:
            self._diagram_prompt_template_name = "diagram"
        self._load_diagram_prompt_engine()

    def _add_translation(self, block: TranslatedCodeBlock) -> None:
        """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
        `block.original`, setting the relevant fields in the translated block. The
        `TranslatedCodeBlock` is updated in-place, nothing is returned. Note that this
        translates *only* the code for this block, not its children.

        Arguments:
            block: An empty `TranslatedCodeBlock`

        Raises:
            ValueError: If documentation was requested but could not be produced,
                or if no LLM is configured.
        """
        if block.translated:
            return

        if block.original.text is None:
            block.translated = True
            return

        documentation = None
        if self._add_documentation:
            # Run the parent translation on a copy so `block` itself is left
            # untouched by the documentation pass.
            documentation_block = deepcopy(block)
            super()._add_translation(documentation_block)
            if not documentation_block.translated:
                message = "Error: unable to produce documentation for code block"
                log.error(message)
                raise ValueError(message)
            documentation = json.loads(documentation_block.text)["docstring"]

        if self._llm is None:
            message = (
                "Model not configured correctly, cannot translate. Try setting "
                "the model"
            )
            log.error(message)
            raise ValueError(message)

        log.debug(f"[{block.name}] Translating...")
        log.debug(f"[{block.name}] Input text:\n{block.original.text}")

        self._parser.set_reference(block.original)

        query_and_parse = self.diagram_prompt | self._llm | self._parser

        # Both prompt variants share these inputs; only the documentation-aware
        # template receives the extra DOCUMENTATION entry.
        prompt_inputs = {
            "SOURCE_CODE": block.original.text,
            "DIAGRAM_TYPE": self._diagram_type,
        }
        if self._add_documentation:
            prompt_inputs["DOCUMENTATION"] = documentation

        block.text = query_and_parse.invoke(prompt_inputs)
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(f"[{block.name}] Output code:\n{block.text}")

    @run_if_changed(
        "_diagram_prompt_template_name",
        "_source_language",
    )
    def _load_diagram_prompt_engine(self) -> None:
        """Load the prompt engine according to this instance's attributes.

        If the relevant fields have not been changed since the last time this method was
        called, nothing happens.

        Raises:
            ValueError: If the diagram prompt template is listed in `SAME_OUTPUT`
                but the source and target languages differ, or is listed in
                `TEXT_OUTPUT` but the target language is not 'text'.
        """
        if self._diagram_prompt_template_name in SAME_OUTPUT:
            if self._target_language != self._source_language:
                # Report the template that was actually checked.
                raise ValueError(
                    f"Prompt template ({self._diagram_prompt_template_name}) "
                    f"suggests source and target languages should match, but "
                    f"do not ({self._source_language} != {self._target_language})"
                )
        if (
            self._diagram_prompt_template_name in TEXT_OUTPUT
            and self._target_language != "text"
        ):
            raise ValueError(
                f"Prompt template ({self._diagram_prompt_template_name}) suggests "
                f"target language should be 'text', but is "
                f"'{self._target_language}'"
            )

        self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
            source_language=self._source_language,
            target_language=self._target_language,
            target_version=self._target_version,
            prompt_template=self._diagram_prompt_template_name,
        )
        self.diagram_prompt = self._diagram_prompt_engine.prompt
942
-
943
-
944
class RequirementsDocumenter(Documenter):
    """RequirementsDocumenter

    A class that translates code from one programming language to its requirements.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the parent constructor, then select
        the "requirements" prompt, the "json" target language, and the
        "requirements" parser type.
        """
        super().__init__(**kwargs)
        self.set_prompt("requirements")
        self.set_target_language("json", None)
        self.set_parser_type("requirements")

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Save a file to disk.

        One entry is written per child of `block`, holding that child's
        generation metadata, its original code text, and the parsed
        requirements. All entries are stored under a top-level 'output' key.

        Arguments:
            block: The `TranslatedCodeBlock` whose children are saved.
            out_path: Destination file. Missing parent directories are created.
        """

        def summarize(chunk) -> dict:
            # Gather the source text, the parsed requirements, and the run
            # metadata for a single chunk of code.
            source_text = chunk.original.text
            parsed = self._parser.parse_combined_output(chunk.complete_text)
            stats = {
                "retries": chunk.total_retries,
                "cost": chunk.total_cost,
                "processing_time": chunk.processing_time,
            }
            return {"metadata": stats, "code": source_text, "requirements": parsed}

        payload = {"output": [summarize(chunk) for chunk in block.children]}
        out_path.parent.mkdir(parents=True, exist_ok=True)
        serialized = json.dumps(payload, indent=2)
        out_path.write_text(serialized, encoding="utf-8")