janus-llm 3.3.2__tar.gz → 3.4.1__tar.gz

Files changed (108)
  1. {janus_llm-3.3.2 → janus_llm-3.4.1}/PKG-INFO +1 -1
  2. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/__init__.py +1 -1
  3. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/cli.py +51 -0
  4. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/converter/converter.py +63 -23
  5. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/converter/requirements.py +5 -0
  6. janus_llm-3.4.1/janus/language/alc/alc.py +185 -0
  7. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/block.py +2 -0
  8. janus_llm-3.4.1/janus/language/naive/simple_ast.py +93 -0
  9. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/parsers/refiner_parser.py +3 -1
  10. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/refiners/refiner.py +17 -7
  11. {janus_llm-3.3.2 → janus_llm-3.4.1}/pyproject.toml +1 -1
  12. janus_llm-3.3.2/janus/language/alc/alc.py +0 -87
  13. janus_llm-3.3.2/janus/language/naive/simple_ast.py +0 -29
  14. {janus_llm-3.3.2 → janus_llm-3.4.1}/LICENSE +0 -0
  15. {janus_llm-3.3.2 → janus_llm-3.4.1}/README.md +0 -0
  16. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/__main__.py +0 -0
  17. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/_tests/__init__.py +0 -0
  18. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/_tests/conftest.py +0 -0
  19. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/_tests/test_cli.py +0 -0
  20. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/converter/__init__.py +0 -0
  21. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/converter/_tests/__init__.py +0 -0
  22. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/converter/_tests/test_translate.py +0 -0
  23. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/converter/diagram.py +0 -0
  24. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/converter/document.py +0 -0
  25. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/converter/evaluate.py +0 -0
  26. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/converter/translate.py +0 -0
  27. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/embedding/__init__.py +0 -0
  28. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/embedding/_tests/__init__.py +0 -0
  29. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/embedding/_tests/test_collections.py +0 -0
  30. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/embedding/_tests/test_database.py +0 -0
  31. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/embedding/_tests/test_vectorize.py +0 -0
  32. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/embedding/collections.py +0 -0
  33. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/embedding/database.py +0 -0
  34. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/embedding/embedding_models_info.py +0 -0
  35. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/embedding/vectorize.py +0 -0
  36. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/__init__.py +0 -0
  37. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/_tests/__init__.py +0 -0
  38. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/_tests/test_combine.py +0 -0
  39. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/_tests/test_splitter.py +0 -0
  40. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/alc/__init__.py +0 -0
  41. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/alc/_tests/__init__.py +0 -0
  42. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/alc/_tests/test_alc.py +0 -0
  43. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/binary/__init__.py +0 -0
  44. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/binary/_tests/__init__.py +0 -0
  45. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/binary/_tests/test_binary.py +0 -0
  46. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/binary/binary.py +0 -0
  47. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/binary/reveng/decompile_script.py +0 -0
  48. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/combine.py +0 -0
  49. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/file.py +0 -0
  50. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/mumps/__init__.py +0 -0
  51. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/mumps/_tests/__init__.py +0 -0
  52. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/mumps/_tests/test_mumps.py +0 -0
  53. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/mumps/mumps.py +0 -0
  54. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/mumps/patterns.py +0 -0
  55. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/naive/__init__.py +0 -0
  56. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/naive/basic_splitter.py +0 -0
  57. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/naive/chunk_splitter.py +0 -0
  58. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/naive/registry.py +0 -0
  59. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/naive/tag_splitter.py +0 -0
  60. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/node.py +0 -0
  61. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/splitter.py +0 -0
  62. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/treesitter/__init__.py +0 -0
  63. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/treesitter/_tests/__init__.py +0 -0
  64. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/treesitter/_tests/test_treesitter.py +0 -0
  65. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/treesitter/treesitter.py +0 -0
  66. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/llm/__init__.py +0 -0
  67. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/llm/model_callbacks.py +0 -0
  68. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/llm/models_info.py +0 -0
  69. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/__init__.py +0 -0
  70. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/_tests/__init__.py +0 -0
  71. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/_tests/reference.py +0 -0
  72. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/_tests/target.py +0 -0
  73. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/_tests/test_bleu.py +0 -0
  74. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/_tests/test_chrf.py +0 -0
  75. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/_tests/test_file_pairing.py +0 -0
  76. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/_tests/test_llm.py +0 -0
  77. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/_tests/test_reading.py +0 -0
  78. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/_tests/test_rouge_score.py +0 -0
  79. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/_tests/test_similarity_score.py +0 -0
  80. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/_tests/test_treesitter_metrics.py +0 -0
  81. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/bleu.py +0 -0
  82. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/chrf.py +0 -0
  83. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/cli.py +0 -0
  84. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/complexity_metrics.py +0 -0
  85. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/file_pairing.py +0 -0
  86. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/llm_metrics.py +0 -0
  87. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/metric.py +0 -0
  88. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/reading.py +0 -0
  89. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/rouge_score.py +0 -0
  90. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/similarity.py +0 -0
  91. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/metrics/splitting.py +0 -0
  92. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/parsers/__init__.py +0 -0
  93. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/parsers/_tests/__init__.py +0 -0
  94. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/parsers/_tests/test_code_parser.py +0 -0
  95. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/parsers/code_parser.py +0 -0
  96. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/parsers/doc_parser.py +0 -0
  97. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/parsers/eval_parser.py +0 -0
  98. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/parsers/reqs_parser.py +0 -0
  99. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/parsers/uml.py +0 -0
  100. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/prompts/__init__.py +0 -0
  101. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/prompts/prompt.py +0 -0
  102. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/utils/__init__.py +0 -0
  103. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/utils/_tests/__init__.py +0 -0
  104. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/utils/_tests/test_logger.py +0 -0
  105. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/utils/_tests/test_progress.py +0 -0
  106. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/utils/enums.py +0 -0
  107. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/utils/logger.py +0 -0
  108. {janus_llm-3.3.2 → janus_llm-3.4.1}/janus/utils/progress.py +0 -0

{janus_llm-3.3.2 → janus_llm-3.4.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: janus-llm
-Version: 3.3.2
+Version: 3.4.1
 Summary: A transcoding library using LLMs.
 Home-page: https://github.com/janus-llm/janus-llm
 License: Apache 2.0

{janus_llm-3.3.2 → janus_llm-3.4.1}/janus/__init__.py
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
 from janus.converter.translate import Translator
 from janus.metrics import *  # noqa: F403
 
-__version__ = "3.3.2"
+__version__ = "3.4.1"
 
 # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
 warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)

{janus_llm-3.3.2 → janus_llm-3.4.1}/janus/cli.py
@@ -200,6 +200,14 @@ def translate(
             help="Whether to overwrite existing files in the output directory",
         ),
     ] = False,
+    skip_context: Annotated[
+        bool,
+        typer.Option(
+            "--skip-context",
+            help="Prompts will include any context information associated with source"
+            " code blocks, unless this option is specified",
+        ),
+    ] = False,
     temp: Annotated[
         float,
         typer.Option("--temperature", "-T", help="Sampling temperature.", min=0, max=2),
@@ -240,6 +248,13 @@ def translate(
             "If unspecificed, model's default max will be used.",
         ),
     ] = None,
+    skip_refiner: Annotated[
+        bool,
+        typer.Option(
+            "--skip-refiner",
+            help="Whether to skip the refiner for generating output",
+        ),
+    ] = True,
 ):
     try:
         target_language, target_version = target_lang.split("-")
@@ -265,6 +280,8 @@ def translate(
         db_path=db_loc,
         db_config=collections_config,
         splitter_type=splitter_type,
+        skip_context=skip_context,
+        skip_refiner=skip_refiner,
     )
     translator.translate(input_dir, output_dir, overwrite, collection)
 
@@ -322,6 +339,14 @@ def document(
             help="Whether to overwrite existing files in the output directory",
         ),
     ] = False,
+    skip_context: Annotated[
+        bool,
+        typer.Option(
+            "--skip-context",
+            help="Prompts will include any context information associated with source"
+            " code blocks, unless this option is specified",
+        ),
+    ] = False,
     doc_mode: Annotated[
         str,
         typer.Option(
@@ -378,6 +403,13 @@ def document(
             "If unspecificed, model's default max will be used.",
         ),
     ] = None,
+    skip_refiner: Annotated[
+        bool,
+        typer.Option(
+            "--skip-refiner",
+            help="Whether to skip the refiner for generating output",
+        ),
+    ] = True,
 ):
     model_arguments = dict(temperature=temperature)
     collections_config = get_collections_config()
@@ -390,6 +422,8 @@ def document(
         db_path=db_loc,
         db_config=collections_config,
         splitter_type=splitter_type,
+        skip_refiner=skip_refiner,
+        skip_context=skip_context,
     )
     if doc_mode == "madlibs":
         documenter = MadLibsDocumenter(
@@ -458,6 +492,14 @@ def diagram(
             help="Whether to overwrite existing files in the output directory",
         ),
     ] = False,
+    skip_context: Annotated[
+        bool,
+        typer.Option(
+            "--skip-context",
+            help="Prompts will include any context information associated with source"
+            " code blocks, unless this option is specified",
+        ),
+    ] = False,
     temperature: Annotated[
         float,
         typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
@@ -494,6 +536,13 @@ def diagram(
             click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
         ),
     ] = "file",
+    skip_refiner: Annotated[
+        bool,
+        typer.Option(
+            "--skip-refiner",
+            help="Whether to skip the refiner for generating output",
+        ),
+    ] = True,
 ):
     model_arguments = dict(temperature=temperature)
     collections_config = get_collections_config()
@@ -507,6 +556,8 @@ def diagram(
         diagram_type=diagram_type,
         add_documentation=add_documentation,
         splitter_type=splitter_type,
+        skip_refiner=skip_refiner,
+        skip_context=skip_context,
     )
     diagram_generator.translate(input_dir, output_dir, overwrite, collection)
 
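
Both new options are wired identically through `translate`, `document`, and `diagram` and forwarded to the converter constructors. A hedged library-level sketch of the same call follows; only `splitter_type`, `skip_context`, `skip_refiner`, and the `translate(...)` call are taken from this diff, while the remaining keyword names and all values are illustrative assumptions:

    from pathlib import Path

    from janus.converter.translate import Translator

    # Illustrative values; "model" and "source_language" are assumed keyword names.
    translator = Translator(
        model="gpt-4o",
        source_language="ibmhlasm",
        splitter_type="ast-flex-listing",  # new splitter registered in 3.4.1
        skip_context=False,                # keep splitter context tags in the prompt
        skip_refiner=True,                 # 3.4.1 default: plain retry parser, no refiner
    )
    # Arguments mirror the CLI call: input_dir, output_dir, overwrite, collection
    translator.translate(Path("legacy_src"), Path("translated"), False, None)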

{janus_llm-3.3.2 → janus_llm-3.4.1}/janus/converter/converter.py
@@ -1,15 +1,14 @@
 import functools
 import json
-import math
 import time
 from pathlib import Path
-from typing import Any
+from typing import Any, List, Optional, Tuple
 
 from langchain.output_parsers import RetryWithErrorOutputParser
 from langchain_core.exceptions import OutputParserException
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.output_parsers import BaseOutputParser
-from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
 from langchain_core.runnables import RunnableLambda, RunnableParallel
 from openai import BadRequestError, RateLimitError
 from pydantic import ValidationError
@@ -77,6 +76,8 @@ class Converter:
         prune_node_types: tuple[str, ...] = (),
         splitter_type: str = "file",
         refiner_type: str = "basic",
+        skip_refiner: bool = True,
+        skip_context: bool = False,
     ) -> None:
         """Initialize a Converter instance.
 
@@ -97,6 +98,8 @@ class Converter:
             splitter_type: The type of splitter to use. Valid values are `"file"`,
                 `"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
             refiner_type: The type of refiner to use. Valid values are `"basic"`.
+            skip_refiner: Whether to skip the refiner.
+            skip_context: Whether to skip adding context to the prompt.
         """
         self._changed_attrs: set = set()
 
@@ -132,6 +135,8 @@ class Converter:
         self._refiner_type: str
         self._refiner: Refiner
 
+        self.skip_refiner = skip_refiner
+
         self.set_splitter(splitter_type=splitter_type)
         self.set_refiner(refiner_type=refiner_type)
         self.set_model(model_name=model, **model_arguments)
@@ -142,6 +147,8 @@ class Converter:
         self.set_db_path(db_path=db_path)
         self.set_db_config(db_config=db_config)
 
+        self.skip_context = skip_context
+
         # Child class must call this. Should we enforce somehow?
         # self._load_parameters()
 
@@ -290,7 +297,7 @@ class Converter:
         """
         if self._refiner_type == "basic":
             self._refiner = BasicRefiner(
-                "basic_refinement", self._model_name, self._source_language
+                "basic_refinement", self._model_id, self._source_language
             )
         else:
             raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
@@ -595,37 +602,41 @@ class Converter:
         self._parser.set_reference(block.original)
 
         # Retries with just the output and the error
-        n1 = round(self.max_prompts ** (1 / 3))
+        n1 = round(self.max_prompts ** (1 / 2))
 
         # Retries with the input, output, and error
-        n2 = round((self.max_prompts // n1) ** (1 / 2))
+        n2 = round(self.max_prompts // n1)
 
         # Retries with just the input
-        n3 = math.ceil(self.max_prompts / (n1 * n2))
-
-        refine_output = RefinerParser(
-            parser=self._parser,
-            initial_prompt=self._prompt.format(**{"SOURCE_CODE": block.original.text}),
-            refiner=self._refiner,
-            max_retries=n1,
-            llm=self._llm,
-        )
-        retry = RetryWithErrorOutputParser.from_llm(
-            llm=self._llm,
-            parser=refine_output,
-            max_retries=n2,
-        )
+        if not self.skip_context:
+            self._make_prompt_additions(block)
+        if not self.skip_refiner:  # Make replacements in the prompt
+            refine_output = RefinerParser(
+                parser=self._parser,
+                initial_prompt=self._prompt.format(
+                    **{"SOURCE_CODE": block.original.text}
+                ),
+                refiner=self._refiner,
+                max_retries=n1,
+                llm=self._llm,
+            )
+        else:
+            refine_output = RetryWithErrorOutputParser.from_llm(
+                llm=self._llm,
+                parser=self._parser,
+                max_retries=n1,
+            )
         completion_chain = self._prompt | self._llm
         chain = RunnableParallel(
             completion=completion_chain, prompt_value=self._prompt
-        ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
-        for _ in range(n3):
+        ) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
+        for _ in range(n2):
             try:
                 return chain.invoke({"SOURCE_CODE": block.original.text})
             except OutputParserException:
                 pass
 
-        raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
+        raise OutputParserException(f"Failed to parse after {n1*n2} retries")
 
     def _get_output_obj(
         self, block: TranslatedCodeBlock
@@ -648,6 +659,35 @@ class Converter:
             output=output,
         )
 
+    @staticmethod
+    def _get_prompt_additions(block) -> Optional[List[Tuple[str, str]]]:
+        """Get a list of strings to append to the prompt.
+
+        Arguments:
+            block: The `TranslatedCodeBlock` to save to a file.
+        """
+        return [(key, item) for key, item in block.context_tags.items()]
+
+    def _make_prompt_additions(self, block: CodeBlock):
+        # Prepare the additional context to prepend
+        additional_context = "".join(
+            [
+                f"{context_tag}: {context}\n"
+                for context_tag, context in self._get_prompt_additions(block)
+            ]
+        )
+
+        # Iterate through existing messages to find and update the system message
+        for i, message in enumerate(self._prompt.messages):
+            if isinstance(message, SystemMessagePromptTemplate):
+                # Prepend the additional context to the system message
+                updated_system_message = SystemMessagePromptTemplate.from_template(
+                    additional_context + message.prompt.template
+                )
+                # Directly modify the message in the list
+                self._prompt.messages[i] = updated_system_message
+                break  # Assuming there's only one system message to update
+
     def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
         """Save a file to disk.
 
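
The retry budget in the chain above is now split two ways instead of three: `n1` retries happen inside the refiner parser (or the plain `RetryWithErrorOutputParser` when the refiner is skipped), and the outer loop re-invokes the chain `n2` times, keeping the total near `max_prompts`. A small arithmetic sketch with an illustrative budget:

    # Illustrative budget only; the package derives n1/n2 from self.max_prompts.
    max_prompts = 10
    n1 = round(max_prompts ** (1 / 2))  # 3 retries inside the parser
    n2 = round(max_prompts // n1)       # 3 outer chain invocations
    print(n1 * n2)                      # 9 total attempts, roughly max_prompts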

{janus_llm-3.3.2 → janus_llm-3.4.1}/janus/converter/requirements.py
@@ -22,6 +22,11 @@ class RequirementsDocumenter(Documenter):
         self._combiner = ChunkCombiner()
         self._parser = RequirementsParser()
 
+    @staticmethod
+    def get_prompt_replacements(block) -> dict[str, str]:
+        prompt_replacements: dict[str, str] = {"SOURCE_CODE": block.original.text}
+        return prompt_replacements
+
     def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
         """Save a file to disk.
 

janus_llm-3.4.1/janus/language/alc/alc.py (new file)
@@ -0,0 +1,185 @@
+import re
+from typing import Optional
+
+from langchain.schema.language_model import BaseLanguageModel
+
+from janus.language.block import CodeBlock
+from janus.language.combine import Combiner
+from janus.language.node import NodeType
+from janus.language.treesitter import TreeSitterSplitter
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class AlcCombiner(Combiner):
+    """A class that combines code blocks into ALC files."""
+
+    def __init__(self) -> None:
+        """Initialize a AlcCombiner instance."""
+        super().__init__("ibmhlasm")
+
+
+class AlcSplitter(TreeSitterSplitter):
+    """A class for splitting ALC code into functional blocks to prompt
+    with for transcoding.
+    """
+
+    def __init__(
+        self,
+        model: None | BaseLanguageModel = None,
+        max_tokens: int = 4096,
+        protected_node_types: tuple[str, ...] = (),
+        prune_node_types: tuple[str, ...] = (),
+        prune_unprotected: bool = False,
+    ):
+        """Initialize a AlcSplitter instance.
+
+        Arguments:
+            max_tokens: The maximum number of tokens supported by the model
+        """
+        super().__init__(
+            language="ibmhlasm",
+            model=model,
+            max_tokens=max_tokens,
+            protected_node_types=protected_node_types,
+            prune_node_types=prune_node_types,
+            prune_unprotected=prune_unprotected,
+        )
+
+    def _get_ast(self, code: str) -> CodeBlock:
+        root = super()._get_ast(code)
+
+        # Current treesitter implementation does not nest csects and dsects
+        # The loop below nests nodes following csect/dsect instructions into
+        # the children of that instruction
+        sect_types = {"csect_instruction", "dsect_instruction"}
+        queue: list[CodeBlock] = [root]
+        while queue:
+            block = queue.pop(0)
+
+            # Search this children for csects and dsects. Create a list of groups
+            # where each group is a csect or dsect, starting with the csect/dsect
+            # instruction and containing all the subsequent nodes up until the
+            # next csect or dsect instruction
+            sects: list[list[CodeBlock]] = [[]]
+            for c in block.children:
+                if c.node_type == "csect_instruction":
+                    c.context_tags["alc_section"] = "CSECT"
+                    sects.append([c])
+                elif c.node_type == "dsect_instruction":
+                    c.context_tags["alc_section"] = "DSECT"
+                    sects.append([c])
+                else:
+                    sects[-1].append(c)
+
+            sects = [s for s in sects if s]
+
+            # Restructure the tree, making the head of each group the parent
+            # of all the remaining nodes in that group
+            if len(sects) > 1:
+                block.children = []
+                for sect in sects:
+                    if sect[0].node_type in sect_types:
+                        sect_node = self.merge_nodes(sect)
+                        sect_node.children = sect
+                        sect_node.node_type = NodeType(str(sect[0].node_type)[:5])
+                        block.children.append(sect_node)
+                    else:
+                        block.children.extend(sect)
+
+            # Push the children onto the queue
+            queue.extend(block.children)
+
+        return root
+
+
+class AlcListingSplitter(AlcSplitter):
+    """A class for splitting ALC listing code into functional blocks to
+    prompt with for transcoding.
+    """
+
+    def __init__(
+        self,
+        model: None | BaseLanguageModel = None,
+        max_tokens: int = 4096,
+        protected_node_types: tuple[str, ...] = (),
+        prune_node_types: tuple[str, ...] = (),
+        prune_unprotected: bool = False,
+    ):
+        """Initialize a AlcSplitter instance.
+
+
+        Arguments:
+            max_tokens: The maximum number of tokens supported by the model
+        """
+        # The string to mark the end of the listing header
+        self.header_indicator_str: str = (
+            "Loc Object Code Addr1 Addr2 Stmt Source Statement"
+        )
+        # How many characters to trim from the right side to remove the address column
+        self.address_column_chars: int = 10
+        # The string to mark the end of the left margin
+        self.left_margin_indicator_str: str = "Stmt"
+        super().__init__(
+            model=model,
+            max_tokens=max_tokens,
+            protected_node_types=protected_node_types,
+            prune_node_types=prune_node_types,
+            prune_unprotected=prune_unprotected,
+        )
+
+    def _get_ast(self, code: str) -> CodeBlock:
+        active_usings = self.get_active_usings(code)
+        code = self.preproccess_assembly(code)
+        ast: CodeBlock = super()._get_ast(code)
+        ast.context_tags["active_usings"] = active_usings
+        return ast
+
+    def preproccess_assembly(self, code: str) -> str:
+        """Remove non-essential lines from an assembly snippet"""
+
+        lines = code.splitlines()
+        lines = self.strip_header_and_left(lines)
+        lines = self.strip_addresses(lines)
+        return "".join(str(line) for line in lines)
+
+    def get_active_usings(self, code: str) -> Optional[str]:
+        """Look for 'active usings' in the ALC listing header"""
+        lines = code.splitlines()
+        for line in lines:
+            if "Active Usings:" in line:
+                return line.split("Active Usings:")[1]
+        return None
+
+    def strip_header_and_left(
+        self,
+        lines: list[str],
+    ) -> list[str]:
+        """Remove the header and the left panel from the assembly sample"""
+
+        esd_regex = re.compile(f".*{self.header_indicator_str}.*")
+
+        header_end_index: int = [
+            i for i, item in enumerate(lines) if re.search(esd_regex, item)
+        ][0]
+
+        left_content_end_column = lines[header_end_index].find(
+            self.left_margin_indicator_str
+        )
+        hori_output_lines = lines[(header_end_index + 1) :]
+
+        left_output_lines = [
+            line[left_content_end_column + 5 :] for line in hori_output_lines
+        ]
+        return left_output_lines
+
+    def strip_addresses(self, lines: list[str]) -> list[str]:
+        """Strip the addresses which run down the right side of the assembly snippet"""
+
+        stripped_lines = [line[: -self.address_column_chars] for line in lines]
+        return stripped_lines
+
+    def strip_footer(self, lines: list[str]):
+        """Strip the footer from the assembly snippet"""
+        return NotImplementedError
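
For readers unfamiliar with ALC listings, the preprocessing in `AlcListingSplitter` trims the assembler's bookkeeping columns before tree-sitter sees the code. A standalone sketch of that logic with invented listing lines (the real class does this via `strip_header_and_left` and `strip_addresses`):

    # Invented sample lines; real listings are wider and paginated.
    header_indicator = "Loc Object Code Addr1 Addr2 Stmt Source Statement"
    lines = [
        " Active Usings: None",
        "  " + header_indicator,
        "000000                1 MYPROG   CSECT            00001000",
        "000000 18CF           2          LR    12,15      00002000",
    ]

    # strip_header_and_left: drop everything through the header row, then cut
    # the left panel (everything up to 5 characters past the "Stmt" column).
    header_end = next(i for i, line in enumerate(lines) if header_indicator in line)
    left_end = lines[header_end].find("Stmt")
    body = [line[left_end + 5:] for line in lines[header_end + 1:]]

    # strip_addresses: trim the 10-character address column on the right.
    body = [line[:-10] for line in body]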

{janus_llm-3.3.2 → janus_llm-3.4.1}/janus/language/block.py
@@ -45,6 +45,7 @@ class CodeBlock:
         children: list[ForwardRef("CodeBlock")],
         embedding_id: Optional[str] = None,
         affixes: Tuple[str, str] = ("", ""),
+        context_tags: dict[str, str] = {},
     ) -> None:
         self.id: Hashable = id
         self.name: Optional[str] = name
@@ -59,6 +60,7 @@ class CodeBlock:
         self.children: list[ForwardRef("CodeBlock")] = sorted(children)
         self.embedding_id: Optional[str] = embedding_id
         self.affixes: Tuple[str, str] = affixes
+        self.context_tags: dict[str, str] = context_tags
 
         self.complete = True
         self.omit_prefix = True
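
`context_tags` is the hand-off between the splitters and the converter: the ALC splitters stamp blocks with `alc_section` and `active_usings`, and `Converter._make_prompt_additions` prepends each tag to the system message. A short sketch with invented tag values:

    # Invented values; in practice the splitters populate CodeBlock.context_tags.
    context_tags = {"alc_section": "CSECT", "active_usings": " R12->MYPROG"}

    # Mirrors Converter._make_prompt_additions: one "tag: value" line per entry,
    # prepended to the existing system prompt template.
    additional_context = "".join(f"{tag}: {value}\n" for tag, value in context_tags.items())
    print(additional_context)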

janus_llm-3.4.1/janus/language/naive/simple_ast.py (new file)
@@ -0,0 +1,93 @@
+from janus.language.alc.alc import AlcListingSplitter, AlcSplitter
+from janus.language.mumps.mumps import MumpsSplitter
+from janus.language.naive.registry import register_splitter
+from janus.language.splitter import Splitter
+from janus.language.treesitter import TreeSitterSplitter
+from janus.utils.enums import LANGUAGES
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+@register_splitter("ast-flex")
+def get_flexible_ast(language: str, **kwargs) -> Splitter:
+    """Get a flexible AST splitter for the given language.
+
+    Arguments:
+        language: The language to get the splitter for.
+
+    Returns:
+        A flexible AST splitter for the given language.
+    """
+    if language == "ibmhlasm":
+        return AlcSplitter(**kwargs)
+    elif language == "mumps":
+        return MumpsSplitter(**kwargs)
+    else:
+        return TreeSitterSplitter(language=language, **kwargs)
+
+
+@register_splitter("ast-strict")
+def get_strict_ast(language: str, **kwargs) -> Splitter:
+    """Get a strict AST splitter for the given language.
+
+    The strict splitter will only return nodes that are of a functional type.
+
+    Arguments:
+        language: The language to get the splitter for.
+
+    Returns:
+        A strict AST splitter for the given language.
+    """
+    kwargs.update(
+        protected_node_types=LANGUAGES[language]["functional_node_types"],
+        prune_unprotected=True,
+    )
+    if language == "ibmhlasm":
+        return AlcSplitter(**kwargs)
+    elif language == "mumps":
+        return MumpsSplitter(**kwargs)
+    else:
+        return TreeSitterSplitter(language=language, **kwargs)
+
+
+@register_splitter("ast-strict-listing")
+def get_strict_listing_ast(language: str, **kwargs) -> Splitter:
+    """Get a strict AST splitter for the given language. This splitter is intended for
+    use with IBM HLASM.
+
+    The strict splitter will only return nodes that are of a functional type.
+
+    Arguments:
+        language: The language to get the splitter for.
+
+    Returns:
+        A strict AST splitter for the given language.
+    """
+    kwargs.update(
+        protected_node_types=LANGUAGES[language]["functional_node_types"],
+        prune_unprotected=True,
+    )
+    if language == "ibmhlasm":
+        return AlcListingSplitter(**kwargs)
+    else:
+        log.warning("Listing splitter is only intended for use with IBMHLASM!")
+        return TreeSitterSplitter(language=language, **kwargs)
+
+
+@register_splitter("ast-flex-listing")
+def get_flexible_listing_ast(language: str, **kwargs) -> Splitter:
+    """Get a flexible AST splitter for the given language. This splitter is intended for
+    use with IBM HLASM.
+
+    Arguments:
+        language: The language to get the splitter for.
+
+    Returns:
+        A flexible AST splitter for the given language.
+    """
+    if language == "ibmhlasm":
+        return AlcListingSplitter(**kwargs)
+    else:
+        log.warning("Listing splitter is only intended for use with IBMHLASM!")
+        return TreeSitterSplitter(language=language, **kwargs)
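
Both listing variants are exposed through the same registry mechanism as the existing splitters, so they become selectable `splitter_type` values (the CLI builds its choices from the `CUSTOM_SPLITTERS` mapping). A hedged lookup sketch, assuming `CUSTOM_SPLITTERS` in `janus.language.naive.registry` is the dict that `register_splitter` populates:

    from janus.language.naive.registry import CUSTOM_SPLITTERS  # assumed registry location

    # Assumed usage: the key is the name passed to register_splitter, the value
    # is the factory function defined above.
    factory = CUSTOM_SPLITTERS["ast-flex-listing"]
    splitter = factory("ibmhlasm", max_tokens=4096)  # builds an AlcListingSplitter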

{janus_llm-3.3.2 → janus_llm-3.4.1}/janus/parsers/refiner_parser.py
@@ -40,7 +40,9 @@ class RefinerParser(BaseOutputParser):
                 return self.parser.parse(text)
             except OutputParserException as oe:
                 err = str(oe)
-                new_prompt, prompt_arguments = self.refiner.refine(last_prompt, text, err)
+                new_prompt, prompt_arguments = self.refiner.refine(
+                    self.initial_prompt, last_prompt, text, err
+                )
                 new_chain = new_prompt | self.llm
                 text = new_chain.invoke(prompt_arguments)
                 last_prompt = new_prompt.format(**prompt_arguments)

{janus_llm-3.3.2 → janus_llm-3.4.1}/janus/refiners/refiner.py
@@ -5,7 +5,12 @@ from janus.llm.models_info import MODEL_PROMPT_ENGINES
 
 class Refiner:
     def refine(
-        self, original_prompt: str, original_output: str, errors: str, **kwargs
+        self,
+        original_prompt: str,
+        previous_prompt: str,
+        previous_output: str,
+        errors: str,
+        **kwargs,
     ) -> tuple[ChatPromptTemplate, dict[str, str]]:
         """Creates a new prompt based on feedback from original results
 
@@ -24,22 +29,27 @@ class BasicRefiner(Refiner):
     def __init__(
         self,
         prompt_name: str,
-        model_name: str,
+        model_id: str,
         source_language: str,
     ) -> None:
         """Basic refiner, asks llm to fix output of previous prompt given errors
 
         Arguments:
             prompt_name: refinement prompt name to use
-            model_name: name of llm to use
+            model_id: ID of the llm to use. Found in models_info.py
            source_language: source_langauge to use
         """
         self._prompt_name = prompt_name
-        self._model_name = model_name
+        self._model_id = model_id
         self._source_language = source_language
 
     def refine(
-        self, original_prompt: str, original_output: str, errors: str, **kwargs
+        self,
+        original_prompt: str,
+        previous_prompt: str,
+        previous_output: str,
+        errors: str,
+        **kwargs,
     ) -> tuple[ChatPromptTemplate, dict[str, str]]:
         """Creates a new prompt based on feedback from original results
 
@@ -51,13 +61,13 @@ class BasicRefiner(Refiner):
         Returns:
             Tuple of new prompt and prompt arguments
         """
-        prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
+        prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
            prompt_template=self._prompt_name,
            source_language=self._source_language,
        )
         prompt_arguments = {
             "ORIGINAL_PROMPT": original_prompt,
-            "OUTPUT": original_output,
+            "OUTPUT": previous_output,
             "ERRORS": errors,
         }
         return prompt_engine.prompt, prompt_arguments
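
Any refiner implementation now has to accept the wider signature (original prompt, previous prompt, previous output, errors). A minimal conforming subclass as a sketch; only the signature and the `ORIGINAL_PROMPT`/`OUTPUT`/`ERRORS` argument keys are taken from this diff, and the prompt text itself is invented:

    from langchain_core.prompts import ChatPromptTemplate

    from janus.refiners.refiner import Refiner


    class EchoRefiner(Refiner):
        """Toy refiner (not part of janus) matching the 3.4.1 refine() signature."""

        def refine(
            self,
            original_prompt: str,
            previous_prompt: str,
            previous_output: str,
            errors: str,
            **kwargs,
        ) -> tuple[ChatPromptTemplate, dict[str, str]]:
            # previous_prompt is available but unused here, mirroring BasicRefiner.
            # Feed the failing output and the parser errors back to the model.
            prompt = ChatPromptTemplate.from_template(
                "The previous attempt could not be parsed.\n"
                "Original request:\n{ORIGINAL_PROMPT}\n"
                "Previous output:\n{OUTPUT}\n"
                "Errors:\n{ERRORS}\n"
                "Return a corrected answer."
            )
            return prompt, {
                "ORIGINAL_PROMPT": original_prompt,
                "OUTPUT": previous_output,
                "ERRORS": errors,
            }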

{janus_llm-3.3.2 → janus_llm-3.4.1}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "janus-llm"
-version = "3.3.2"
+version = "3.4.1"
 description = "A transcoding library using LLMs."
 authors = ["Michael Doyle <mdoyle@mitre.org>", "Chris Glasz <cglasz@mitre.org>",
     "Chris Tohline <ctohline@mitre.org>", "William Macke <wmacke@mitre.org>",

janus_llm-3.3.2/janus/language/alc/alc.py (removed)
@@ -1,87 +0,0 @@
-from langchain.schema.language_model import BaseLanguageModel
-
-from janus.language.block import CodeBlock
-from janus.language.combine import Combiner
-from janus.language.node import NodeType
-from janus.language.treesitter import TreeSitterSplitter
-from janus.utils.logger import create_logger
-
-log = create_logger(__name__)
-
-
-class AlcCombiner(Combiner):
-    """A class that combines code blocks into ALC files."""
-
-    def __init__(self) -> None:
-        """Initialize a AlcCombiner instance."""
-        super().__init__("ibmhlasm")
-
-
-class AlcSplitter(TreeSitterSplitter):
-    """A class for splitting ALC code into functional blocks to prompt
-    with for transcoding.
-    """
-
-    def __init__(
-        self,
-        model: None | BaseLanguageModel = None,
-        max_tokens: int = 4096,
-        protected_node_types: tuple[str, ...] = (),
-        prune_node_types: tuple[str, ...] = (),
-        prune_unprotected: bool = False,
-    ):
-        """Initialize a AlcSplitter instance.
-
-        Arguments:
-            max_tokens: The maximum number of tokens supported by the model
-        """
-        super().__init__(
-            language="ibmhlasm",
-            model=model,
-            max_tokens=max_tokens,
-            protected_node_types=protected_node_types,
-            prune_node_types=prune_node_types,
-            prune_unprotected=prune_unprotected,
-        )
-
-    def _get_ast(self, code: str) -> CodeBlock:
-        root = super()._get_ast(code)
-
-        # Current treesitter implementation does not nest csects and dsects
-        # The loop below nests nodes following csect/dsect instructions into
-        # the children of that instruction
-        sect_types = {"csect_instruction", "dsect_instruction"}
-        queue: list[CodeBlock] = [root]
-        while queue:
-            block = queue.pop(0)
-
-            # Search this children for csects and dsects. Create a list of groups
-            # where each group is a csect or dsect, starting with the csect/dsect
-            # instruction and containing all the subsequent nodes up until the
-            # next csect or dsect instruction
-            sects: list[list[CodeBlock]] = [[]]
-            for c in block.children:
-                if c.node_type in sect_types:
-                    sects.append([c])
-                else:
-                    sects[-1].append(c)
-
-            sects = [s for s in sects if s]
-
-            # Restructure the tree, making the head of each group the parent
-            # of all the remaining nodes in that group
-            if len(sects) > 1:
-                block.children = []
-                for sect in sects:
-                    if sect[0].node_type in sect_types:
-                        sect_node = self.merge_nodes(sect)
-                        sect_node.children = sect
-                        sect_node.node_type = NodeType(str(sect[0].node_type)[:5])
-                        block.children.append(sect_node)
-                    else:
-                        block.children.extend(sect)
-
-            # Push the children onto the queue
-            queue.extend(block.children)
-
-        return root

janus_llm-3.3.2/janus/language/naive/simple_ast.py (removed)
@@ -1,29 +0,0 @@
-from janus.language.alc.alc import AlcSplitter
-from janus.language.mumps.mumps import MumpsSplitter
-from janus.language.naive.registry import register_splitter
-from janus.language.treesitter import TreeSitterSplitter
-from janus.utils.enums import LANGUAGES
-
-
-@register_splitter("ast-flex")
-def get_flexible_ast(language: str, **kwargs):
-    if language == "ibmhlasm":
-        return AlcSplitter(**kwargs)
-    elif language == "mumps":
-        return MumpsSplitter(**kwargs)
-    else:
-        return TreeSitterSplitter(language=language, **kwargs)
-
-
-@register_splitter("ast-strict")
-def get_strict_ast(language: str, **kwargs):
-    kwargs.update(
-        protected_node_types=LANGUAGES[language]["functional_node_types"],
-        prune_unprotected=True,
-    )
-    if language == "ibmhlasm":
-        return AlcSplitter(**kwargs)
-    elif language == "mumps":
-        return MumpsSplitter(**kwargs)
-    else:
-        return TreeSitterSplitter(language=language, **kwargs)