janus-llm 3.3.2__tar.gz → 3.4.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (108)
  1. {janus_llm-3.3.2 → janus_llm-3.4.0}/PKG-INFO +1 -1
  2. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/__init__.py +1 -1
  3. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/cli.py +27 -0
  4. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/converter/converter.py +37 -2
  5. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/converter/requirements.py +5 -0
  6. janus_llm-3.4.0/janus/language/alc/alc.py +185 -0
  7. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/block.py +2 -0
  8. janus_llm-3.4.0/janus/language/naive/simple_ast.py +93 -0
  9. {janus_llm-3.3.2 → janus_llm-3.4.0}/pyproject.toml +1 -1
  10. janus_llm-3.3.2/janus/language/alc/alc.py +0 -87
  11. janus_llm-3.3.2/janus/language/naive/simple_ast.py +0 -29
  12. {janus_llm-3.3.2 → janus_llm-3.4.0}/LICENSE +0 -0
  13. {janus_llm-3.3.2 → janus_llm-3.4.0}/README.md +0 -0
  14. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/__main__.py +0 -0
  15. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/_tests/__init__.py +0 -0
  16. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/_tests/conftest.py +0 -0
  17. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/_tests/test_cli.py +0 -0
  18. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/converter/__init__.py +0 -0
  19. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/converter/_tests/__init__.py +0 -0
  20. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/converter/_tests/test_translate.py +0 -0
  21. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/converter/diagram.py +0 -0
  22. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/converter/document.py +0 -0
  23. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/converter/evaluate.py +0 -0
  24. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/converter/translate.py +0 -0
  25. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/embedding/__init__.py +0 -0
  26. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/embedding/_tests/__init__.py +0 -0
  27. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/embedding/_tests/test_collections.py +0 -0
  28. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/embedding/_tests/test_database.py +0 -0
  29. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/embedding/_tests/test_vectorize.py +0 -0
  30. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/embedding/collections.py +0 -0
  31. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/embedding/database.py +0 -0
  32. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/embedding/embedding_models_info.py +0 -0
  33. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/embedding/vectorize.py +0 -0
  34. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/__init__.py +0 -0
  35. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/_tests/__init__.py +0 -0
  36. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/_tests/test_combine.py +0 -0
  37. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/_tests/test_splitter.py +0 -0
  38. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/alc/__init__.py +0 -0
  39. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/alc/_tests/__init__.py +0 -0
  40. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/alc/_tests/test_alc.py +0 -0
  41. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/binary/__init__.py +0 -0
  42. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/binary/_tests/__init__.py +0 -0
  43. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/binary/_tests/test_binary.py +0 -0
  44. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/binary/binary.py +0 -0
  45. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/binary/reveng/decompile_script.py +0 -0
  46. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/combine.py +0 -0
  47. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/file.py +0 -0
  48. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/mumps/__init__.py +0 -0
  49. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/mumps/_tests/__init__.py +0 -0
  50. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/mumps/_tests/test_mumps.py +0 -0
  51. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/mumps/mumps.py +0 -0
  52. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/mumps/patterns.py +0 -0
  53. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/naive/__init__.py +0 -0
  54. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/naive/basic_splitter.py +0 -0
  55. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/naive/chunk_splitter.py +0 -0
  56. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/naive/registry.py +0 -0
  57. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/naive/tag_splitter.py +0 -0
  58. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/node.py +0 -0
  59. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/splitter.py +0 -0
  60. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/treesitter/__init__.py +0 -0
  61. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/treesitter/_tests/__init__.py +0 -0
  62. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/treesitter/_tests/test_treesitter.py +0 -0
  63. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/language/treesitter/treesitter.py +0 -0
  64. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/llm/__init__.py +0 -0
  65. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/llm/model_callbacks.py +0 -0
  66. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/llm/models_info.py +0 -0
  67. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/__init__.py +0 -0
  68. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/_tests/__init__.py +0 -0
  69. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/_tests/reference.py +0 -0
  70. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/_tests/target.py +0 -0
  71. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/_tests/test_bleu.py +0 -0
  72. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/_tests/test_chrf.py +0 -0
  73. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/_tests/test_file_pairing.py +0 -0
  74. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/_tests/test_llm.py +0 -0
  75. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/_tests/test_reading.py +0 -0
  76. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/_tests/test_rouge_score.py +0 -0
  77. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/_tests/test_similarity_score.py +0 -0
  78. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/_tests/test_treesitter_metrics.py +0 -0
  79. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/bleu.py +0 -0
  80. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/chrf.py +0 -0
  81. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/cli.py +0 -0
  82. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/complexity_metrics.py +0 -0
  83. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/file_pairing.py +0 -0
  84. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/llm_metrics.py +0 -0
  85. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/metric.py +0 -0
  86. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/reading.py +0 -0
  87. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/rouge_score.py +0 -0
  88. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/similarity.py +0 -0
  89. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/metrics/splitting.py +0 -0
  90. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/parsers/__init__.py +0 -0
  91. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/parsers/_tests/__init__.py +0 -0
  92. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/parsers/_tests/test_code_parser.py +0 -0
  93. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/parsers/code_parser.py +0 -0
  94. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/parsers/doc_parser.py +0 -0
  95. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/parsers/eval_parser.py +0 -0
  96. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/parsers/refiner_parser.py +0 -0
  97. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/parsers/reqs_parser.py +0 -0
  98. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/parsers/uml.py +0 -0
  99. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/prompts/__init__.py +0 -0
  100. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/prompts/prompt.py +0 -0
  101. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/refiners/refiner.py +0 -0
  102. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/utils/__init__.py +0 -0
  103. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/utils/_tests/__init__.py +0 -0
  104. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/utils/_tests/test_logger.py +0 -0
  105. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/utils/_tests/test_progress.py +0 -0
  106. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/utils/enums.py +0 -0
  107. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/utils/logger.py +0 -0
  108. {janus_llm-3.3.2 → janus_llm-3.4.0}/janus/utils/progress.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: janus-llm
3
- Version: 3.3.2
3
+ Version: 3.4.0
4
4
  Summary: A transcoding library using LLMs.
5
5
  Home-page: https://github.com/janus-llm/janus-llm
6
6
  License: Apache 2.0
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
5
5
  from janus.converter.translate import Translator
6
6
  from janus.metrics import * # noqa: F403
7
7
 
8
- __version__ = "3.3.2"
8
+ __version__ = "3.4.0"
9
9
 
10
10
  # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
11
11
  warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
@@ -200,6 +200,14 @@ def translate(
200
200
  help="Whether to overwrite existing files in the output directory",
201
201
  ),
202
202
  ] = False,
203
+ skip_context: Annotated[
204
+ bool,
205
+ typer.Option(
206
+ "--skip-context",
207
+ help="Prompts will include any context information associated with source"
208
+ " code blocks, unless this option is specified",
209
+ ),
210
+ ] = False,
203
211
  temp: Annotated[
204
212
  float,
205
213
  typer.Option("--temperature", "-T", help="Sampling temperature.", min=0, max=2),
@@ -265,6 +273,7 @@ def translate(
265
273
  db_path=db_loc,
266
274
  db_config=collections_config,
267
275
  splitter_type=splitter_type,
276
+ skip_context=skip_context,
268
277
  )
269
278
  translator.translate(input_dir, output_dir, overwrite, collection)
270
279
 
@@ -322,6 +331,14 @@ def document(
322
331
  help="Whether to overwrite existing files in the output directory",
323
332
  ),
324
333
  ] = False,
334
+ skip_context: Annotated[
335
+ bool,
336
+ typer.Option(
337
+ "--skip-context",
338
+ help="Prompts will include any context information associated with source"
339
+ " code blocks, unless this option is specified",
340
+ ),
341
+ ] = False,
325
342
  doc_mode: Annotated[
326
343
  str,
327
344
  typer.Option(
@@ -390,6 +407,7 @@ def document(
390
407
  db_path=db_loc,
391
408
  db_config=collections_config,
392
409
  splitter_type=splitter_type,
410
+ skip_context=skip_context,
393
411
  )
394
412
  if doc_mode == "madlibs":
395
413
  documenter = MadLibsDocumenter(
@@ -458,6 +476,14 @@ def diagram(
458
476
  help="Whether to overwrite existing files in the output directory",
459
477
  ),
460
478
  ] = False,
479
+ skip_context: Annotated[
480
+ bool,
481
+ typer.Option(
482
+ "--skip-context",
483
+ help="Prompts will include any context information associated with source"
484
+ " code blocks, unless this option is specified",
485
+ ),
486
+ ] = False,
461
487
  temperature: Annotated[
462
488
  float,
463
489
  typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
@@ -507,6 +533,7 @@ def diagram(
507
533
  diagram_type=diagram_type,
508
534
  add_documentation=add_documentation,
509
535
  splitter_type=splitter_type,
536
+ skip_context=skip_context,
510
537
  )
511
538
  diagram_generator.translate(input_dir, output_dir, overwrite, collection)
512
539
 
@@ -3,13 +3,13 @@ import json
3
3
  import math
4
4
  import time
5
5
  from pathlib import Path
6
- from typing import Any
6
+ from typing import Any, List, Optional, Tuple
7
7
 
8
8
  from langchain.output_parsers import RetryWithErrorOutputParser
9
9
  from langchain_core.exceptions import OutputParserException
10
10
  from langchain_core.language_models import BaseLanguageModel
11
11
  from langchain_core.output_parsers import BaseOutputParser
12
- from langchain_core.prompts import ChatPromptTemplate
12
+ from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
13
13
  from langchain_core.runnables import RunnableLambda, RunnableParallel
14
14
  from openai import BadRequestError, RateLimitError
15
15
  from pydantic import ValidationError
@@ -77,6 +77,7 @@ class Converter:
77
77
  prune_node_types: tuple[str, ...] = (),
78
78
  splitter_type: str = "file",
79
79
  refiner_type: str = "basic",
80
+ skip_context: bool = False,
80
81
  ) -> None:
81
82
  """Initialize a Converter instance.
82
83
 
@@ -142,6 +143,8 @@ class Converter:
142
143
  self.set_db_path(db_path=db_path)
143
144
  self.set_db_config(db_config=db_config)
144
145
 
146
+ self.skip_context = skip_context
147
+
145
148
  # Child class must call this. Should we enforce somehow?
146
149
  # self._load_parameters()
147
150
 
@@ -602,6 +605,9 @@ class Converter:
602
605
 
603
606
  # Retries with just the input
604
607
  n3 = math.ceil(self.max_prompts / (n1 * n2))
608
+ # Make replacements in the prompt
609
+ if not self.skip_context:
610
+ self._make_prompt_additions(block)
605
611
 
606
612
  refine_output = RefinerParser(
607
613
  parser=self._parser,
@@ -648,6 +654,35 @@ class Converter:
648
654
  output=output,
649
655
  )
650
656
 
657
@staticmethod
def _get_prompt_additions(block) -> List[Tuple[str, str]]:
    """Collect a block's context tags as ``(tag, value)`` pairs for the prompt.

    The pairs are meant to be rendered as ``"tag: value"`` lines and
    prepended to the prompt's system message.

    Arguments:
        block: The code block whose ``context_tags`` mapping is read.

    Returns:
        The ``(tag, value)`` pairs in the mapping's iteration order. Tags whose
        value is ``None`` are skipped so that no literal "None" ends up in the
        prompt. The list is empty (never ``None``) when there is no context.
    """
    return [
        (tag, value)
        for tag, value in block.context_tags.items()
        if value is not None
    ]
665
+
666
def _make_prompt_additions(self, block: CodeBlock):
    """Prepend ``block``'s context tags to the prompt's system message.

    The first ``SystemMessagePromptTemplate`` in ``self._prompt.messages`` is
    replaced in place by one ``"tag: value"`` line per context tag followed by
    the original system template. The context-free system message is cached
    the first time it is seen, so each call starts from that pristine
    template. Context from earlier blocks therefore never accumulates. If the
    prompt has no system message, it is left unchanged.

    Arguments:
        block: The code block whose context tags are injected.
    """
    base = getattr(self, "_context_free_system_message", None)
    # (Re)capture the pristine system message on first use, or if
    # ``self._prompt`` has been replaced by a new prompt object since then.
    if base is None or base[0] is not self._prompt:
        base = None
        for index, message in enumerate(self._prompt.messages):
            if isinstance(message, SystemMessagePromptTemplate):
                base = (self._prompt, index, message)
                break  # Only the first system message is updated
        if base is None:
            return
        self._context_free_system_message = base
    _, index, system_message = base

    template_format = system_message.prompt.template_format
    additional_context = "".join(
        f"{context_tag}: {context}\n"
        for context_tag, context in self._get_prompt_additions(block)
    )
    if template_format == "f-string":
        # Context values are literal data, not template fields. Escape braces
        # so text such as "{R12}" is not parsed as a missing prompt variable.
        additional_context = additional_context.replace("{", "{{").replace(
            "}", "}}"
        )

    # Directly modify the message in the list so existing references to
    # ``self._prompt`` see the update.
    self._prompt.messages[index] = SystemMessagePromptTemplate.from_template(
        additional_context + system_message.prompt.template,
        template_format=template_format,
    )
685
+
651
686
  def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
652
687
  """Save a file to disk.
653
688
 
@@ -22,6 +22,11 @@ class RequirementsDocumenter(Documenter):
22
22
  self._combiner = ChunkCombiner()
23
23
  self._parser = RequirementsParser()
24
24
 
25
+ @staticmethod
26
+ def get_prompt_replacements(block) -> dict[str, str]:
27
+ prompt_replacements: dict[str, str] = {"SOURCE_CODE": block.original.text}
28
+ return prompt_replacements
29
+
25
30
  def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
26
31
  """Save a file to disk.
27
32
 
@@ -0,0 +1,185 @@
1
+ import re
2
+ from typing import Optional
3
+
4
+ from langchain.schema.language_model import BaseLanguageModel
5
+
6
+ from janus.language.block import CodeBlock
7
+ from janus.language.combine import Combiner
8
+ from janus.language.node import NodeType
9
+ from janus.language.treesitter import TreeSitterSplitter
10
+ from janus.utils.logger import create_logger
11
+
12
+ log = create_logger(__name__)
13
+
14
+
15
class AlcCombiner(Combiner):
    """Combiner that stitches code blocks back together as ALC files.

    The combiner is configured with the ``"ibmhlasm"`` language identifier.
    """

    def __init__(self) -> None:
        """Create a combiner for IBM HLASM (ALC) sources."""
        language = "ibmhlasm"
        super().__init__(language)
21
+
22
+
23
class AlcSplitter(TreeSitterSplitter):
    """A class for splitting ALC code into functional blocks to prompt
    with for transcoding.

    On top of the generic tree-sitter split, ``_get_ast`` nests the nodes
    that follow a CSECT/DSECT instruction under a node for that section. It
    also tags each section instruction with an ``"alc_section"`` context tag.
    """

    # Tree-sitter node type of a section instruction -> value stored under
    # the "alc_section" context tag.
    _SECTION_TAGS: dict[str, str] = {
        "csect_instruction": "CSECT",
        "dsect_instruction": "DSECT",
    }

    def __init__(
        self,
        model: None | BaseLanguageModel = None,
        max_tokens: int = 4096,
        protected_node_types: tuple[str, ...] = (),
        prune_node_types: tuple[str, ...] = (),
        prune_unprotected: bool = False,
    ):
        """Initialize an AlcSplitter instance.

        Arguments:
            model: Passed through to ``TreeSitterSplitter``.
            max_tokens: The maximum number of tokens supported by the model.
            protected_node_types: Passed through to ``TreeSitterSplitter``.
            prune_node_types: Passed through to ``TreeSitterSplitter``.
            prune_unprotected: Passed through to ``TreeSitterSplitter``.
        """
        super().__init__(
            language="ibmhlasm",
            model=model,
            max_tokens=max_tokens,
            protected_node_types=protected_node_types,
            prune_node_types=prune_node_types,
            prune_unprotected=prune_unprotected,
        )

    def _get_ast(self, code: str) -> CodeBlock:
        root = super()._get_ast(code)

        # Current treesitter implementation does not nest csects and dsects.
        # The loop below nests nodes following csect/dsect instructions into
        # the children of that instruction. Traversal is breadth-first. A read
        # index is used instead of ``list.pop(0)``, which is O(n) per pop.
        queue: list[CodeBlock] = [root]
        head = 0
        while head < len(queue):
            block = queue[head]
            head += 1

            # Search this block's children for csects and dsects. Create a
            # list of groups where each group is a csect or dsect, starting
            # with the csect/dsect instruction and containing all subsequent
            # nodes up until the next csect or dsect instruction.
            sects: list[list[CodeBlock]] = [[]]
            for child in block.children:
                section = self._SECTION_TAGS.get(child.node_type)
                if section is None:
                    sects[-1].append(child)
                    continue
                # Assign a new dict instead of mutating ``context_tags`` in
                # place: it may be the default dict shared by every CodeBlock.
                child.context_tags = {**child.context_tags, "alc_section": section}
                sects.append([child])

            sects = [s for s in sects if s]

            # Restructure the tree, making the head of each group the parent
            # of all the remaining nodes in that group
            if len(sects) > 1:
                block.children = []
                for sect in sects:
                    if sect[0].node_type in self._SECTION_TAGS:
                        sect_node = self.merge_nodes(sect)
                        sect_node.children = sect
                        # "csect_instruction" -> "csect", "dsect_instruction" -> "dsect"
                        sect_node.node_type = NodeType(str(sect[0].node_type)[:5])
                        block.children.append(sect_node)
                    else:
                        block.children.extend(sect)

            # Push the children onto the queue
            queue.extend(block.children)

        return root
95
+
96
+
97
class AlcListingSplitter(AlcSplitter):
    """A class for splitting ALC listing code into functional blocks to
    prompt with for transcoding.

    Before parsing, ``preproccess_assembly`` reduces the listing to plain ALC
    source:
    - The header is removed (everything up to and including the column-heading
      line ``header_indicator_str``).
    - The columns left of the source statement are removed.
    - ``address_column_chars`` characters are trimmed from the right of every
      line.

    The text after the first ``Active Usings:`` marker, if present, is attached
    to the root block as the ``"active_usings"`` context tag.
    """

    def __init__(
        self,
        model: None | BaseLanguageModel = None,
        max_tokens: int = 4096,
        protected_node_types: tuple[str, ...] = (),
        prune_node_types: tuple[str, ...] = (),
        prune_unprotected: bool = False,
    ):
        """Initialize an AlcListingSplitter instance.

        Arguments:
            model: Passed through to ``AlcSplitter``.
            max_tokens: The maximum number of tokens supported by the model.
            protected_node_types: Passed through to ``AlcSplitter``.
            prune_node_types: Passed through to ``AlcSplitter``.
            prune_unprotected: Passed through to ``AlcSplitter``.
        """
        # The string to mark the end of the listing header
        self.header_indicator_str: str = (
            "Loc Object Code Addr1 Addr2 Stmt Source Statement"
        )
        # How many characters to trim from the right side to remove the address column
        self.address_column_chars: int = 10
        # The string to mark the end of the left margin; the source statement
        # starts one column after the end of this heading
        self.left_margin_indicator_str: str = "Stmt"
        super().__init__(
            model=model,
            max_tokens=max_tokens,
            protected_node_types=protected_node_types,
            prune_node_types=prune_node_types,
            prune_unprotected=prune_unprotected,
        )

    def _get_ast(self, code: str) -> CodeBlock:
        active_usings = self.get_active_usings(code)
        code = self.preproccess_assembly(code)
        root: CodeBlock = super()._get_ast(code)
        if active_usings is not None:
            # Assign a new dict rather than mutating ``context_tags`` in place:
            # it may be the default dict shared by every CodeBlock.
            root.context_tags = {**root.context_tags, "active_usings": active_usings}
        return root

    def preproccess_assembly(self, code: str) -> str:
        """Remove non-essential lines and columns from an assembly listing.

        (The misspelled method name is kept for backward compatibility.)

        Returns:
            The remaining source, one newline-terminated statement per line.
        """
        lines = code.splitlines()
        lines = self.strip_header_and_left(lines)
        lines = self.strip_addresses(lines)
        # Assembler source is line-oriented. Joining with "" would fuse every
        # statement onto a single line, so terminate each line with "\n".
        return "".join(f"{line}\n" for line in lines)

    def get_active_usings(self, code: str) -> Optional[str]:
        """Look for 'active usings' in the ALC listing header.

        Only the first line containing ``Active Usings:`` is used.

        Returns:
            The text following the marker on that line, with surrounding
            whitespace stripped, or ``None`` if the marker does not occur.
        """
        marker = "Active Usings:"
        for line in code.splitlines():
            if marker in line:
                return line.split(marker)[1].strip()
        return None

    def strip_header_and_left(
        self,
        lines: list[str],
    ) -> list[str]:
        """Remove the header and the left panel from the assembly sample.

        Every line up to and including the first line containing
        ``header_indicator_str`` is dropped. On the remaining lines, the
        characters up to and including the column just after
        ``left_margin_indicator_str`` (as positioned on that header line) are
        dropped.

        Raises:
            ValueError: If no line contains ``header_indicator_str``.
        """
        header_end_index = next(
            (i for i, line in enumerate(lines) if self.header_indicator_str in line),
            None,
        )
        if header_end_index is None:
            raise ValueError(
                f"Listing header {self.header_indicator_str!r} not found; "
                "is the input an assembler listing?"
            )

        margin_column = lines[header_end_index].find(self.left_margin_indicator_str)
        # Skip the margin heading itself plus the single separator column that
        # follows it (5 columns for the default "Stmt").
        source_start = margin_column + len(self.left_margin_indicator_str) + 1
        return [line[source_start:] for line in lines[header_end_index + 1 :]]

    def strip_addresses(self, lines: list[str]) -> list[str]:
        """Strip the addresses which run down the right side of the assembly snippet.

        Removes ``address_column_chars`` characters from the end of each line.
        Lines shorter than that become empty. A width of 0 leaves lines intact.
        """
        width = self.address_column_chars
        # ``line[:-0]`` would be empty, so slice by explicit end index instead.
        return [line[: max(len(line) - width, 0)] for line in lines]

    def strip_footer(self, lines: list[str]):
        """Strip the footer from the assembly snippet.

        Raises:
            NotImplementedError: Always; footer stripping is not implemented yet.
        """
        raise NotImplementedError("Footer stripping is not implemented for ALC listings")
@@ -45,6 +45,7 @@ class CodeBlock:
45
45
  children: list[ForwardRef("CodeBlock")],
46
46
  embedding_id: Optional[str] = None,
47
47
  affixes: Tuple[str, str] = ("", ""),
48
+ context_tags: dict[str, str] = {},
48
49
  ) -> None:
49
50
  self.id: Hashable = id
50
51
  self.name: Optional[str] = name
@@ -59,6 +60,7 @@ class CodeBlock:
59
60
  self.children: list[ForwardRef("CodeBlock")] = sorted(children)
60
61
  self.embedding_id: Optional[str] = embedding_id
61
62
  self.affixes: Tuple[str, str] = affixes
63
+ self.context_tags: dict[str, str] = context_tags
62
64
 
63
65
  self.complete = True
64
66
  self.omit_prefix = True
@@ -0,0 +1,93 @@
1
+ from janus.language.alc.alc import AlcListingSplitter, AlcSplitter
2
+ from janus.language.mumps.mumps import MumpsSplitter
3
+ from janus.language.naive.registry import register_splitter
4
+ from janus.language.splitter import Splitter
5
+ from janus.language.treesitter import TreeSitterSplitter
6
+ from janus.utils.enums import LANGUAGES
7
+ from janus.utils.logger import create_logger
8
+
9
+ log = create_logger(__name__)
10
+
11
+
12
+ @register_splitter("ast-flex")
13
+ def get_flexible_ast(language: str, **kwargs) -> Splitter:
14
+ """Get a flexible AST splitter for the given language.
15
+
16
+ Arguments:
17
+ language: The language to get the splitter for.
18
+
19
+ Returns:
20
+ A flexible AST splitter for the given language.
21
+ """
22
+ if language == "ibmhlasm":
23
+ return AlcSplitter(**kwargs)
24
+ elif language == "mumps":
25
+ return MumpsSplitter(**kwargs)
26
+ else:
27
+ return TreeSitterSplitter(language=language, **kwargs)
28
+
29
+
30
@register_splitter("ast-strict")
def get_strict_ast(language: str, **kwargs) -> Splitter:
    """Build the strict AST splitter registered as ``"ast-strict"``.

    The language's ``functional_node_types`` from ``LANGUAGES`` become the
    protected node types, and every unprotected node is pruned. Dispatch
    matches ``"ast-flex"``: HLASM and MUMPS use dedicated splitters, all other
    languages use the tree-sitter splitter.

    Arguments:
        language: Name of the source language.
        **kwargs: Forwarded to the splitter's constructor; the two strictness
            options above override any caller-supplied values.

    Returns:
        The splitter instance for ``language``.
    """
    kwargs["protected_node_types"] = LANGUAGES[language]["functional_node_types"]
    kwargs["prune_unprotected"] = True

    if language == "ibmhlasm":
        return AlcSplitter(**kwargs)
    if language == "mumps":
        return MumpsSplitter(**kwargs)
    return TreeSitterSplitter(language=language, **kwargs)
52
+
53
+
54
@register_splitter("ast-strict-listing")
def get_strict_listing_ast(language: str, **kwargs) -> Splitter:
    """Build the strict listing splitter registered as ``"ast-strict-listing"``.

    Intended for IBM HLASM listings only. Strictness is configured as for
    ``"ast-strict"``: the language's functional node types are protected and
    everything else is pruned. For any language other than ``"ibmhlasm"``, a
    warning is logged and a plain tree-sitter splitter is returned instead.

    Arguments:
        language: Name of the source language.
        **kwargs: Forwarded to the splitter's constructor.

    Returns:
        The splitter instance for ``language``.
    """
    kwargs |= {
        "protected_node_types": LANGUAGES[language]["functional_node_types"],
        "prune_unprotected": True,
    }
    if language != "ibmhlasm":
        log.warning("Listing splitter is only intended for use with IBMHLASM!")
        return TreeSitterSplitter(language=language, **kwargs)
    return AlcListingSplitter(**kwargs)
76
+
77
+
78
@register_splitter("ast-flex-listing")
def get_flexible_listing_ast(language: str, **kwargs) -> Splitter:
    """Build the flexible listing splitter registered as ``"ast-flex-listing"``.

    Intended for IBM HLASM listings only. For any language other than
    ``"ibmhlasm"``, a warning is logged and a plain tree-sitter splitter is
    returned instead.

    Arguments:
        language: Name of the source language.
        **kwargs: Forwarded unchanged to the splitter's constructor.

    Returns:
        The splitter instance for ``language``.
    """
    if language != "ibmhlasm":
        log.warning("Listing splitter is only intended for use with IBMHLASM!")
        return TreeSitterSplitter(language=language, **kwargs)
    return AlcListingSplitter(**kwargs)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "janus-llm"
3
- version = "3.3.2"
3
+ version = "3.4.0"
4
4
  description = "A transcoding library using LLMs."
5
5
  authors = ["Michael Doyle <mdoyle@mitre.org>", "Chris Glasz <cglasz@mitre.org>",
6
6
  "Chris Tohline <ctohline@mitre.org>", "William Macke <wmacke@mitre.org>",
@@ -1,87 +0,0 @@
1
- from langchain.schema.language_model import BaseLanguageModel
2
-
3
- from janus.language.block import CodeBlock
4
- from janus.language.combine import Combiner
5
- from janus.language.node import NodeType
6
- from janus.language.treesitter import TreeSitterSplitter
7
- from janus.utils.logger import create_logger
8
-
9
- log = create_logger(__name__)
10
-
11
-
12
class AlcCombiner(Combiner):
    """Reassemble code blocks into ALC (IBM HLASM) files."""

    def __init__(self) -> None:
        """Set up the combiner for the ``"ibmhlasm"`` language."""
        super(AlcCombiner, self).__init__("ibmhlasm")
18
-
19
-
20
class AlcSplitter(TreeSitterSplitter):
    """Split ALC (IBM HLASM) source into functional blocks for transcoding.

    The tree-sitter grammar leaves CSECT/DSECT instructions as siblings of the
    statements they introduce. ``_get_ast`` re-nests those statements under a
    node for their section.
    """

    def __init__(
        self,
        model: None | BaseLanguageModel = None,
        max_tokens: int = 4096,
        protected_node_types: tuple[str, ...] = (),
        prune_node_types: tuple[str, ...] = (),
        prune_unprotected: bool = False,
    ):
        """Create an ALC splitter.

        Arguments:
            model: Passed through to ``TreeSitterSplitter``.
            max_tokens: The maximum number of tokens supported by the model.
            protected_node_types: Passed through to ``TreeSitterSplitter``.
            prune_node_types: Passed through to ``TreeSitterSplitter``.
            prune_unprotected: Passed through to ``TreeSitterSplitter``.
        """
        super().__init__(
            language="ibmhlasm",
            model=model,
            max_tokens=max_tokens,
            protected_node_types=protected_node_types,
            prune_node_types=prune_node_types,
            prune_unprotected=prune_unprotected,
        )

    def _get_ast(self, code: str) -> CodeBlock:
        root = super()._get_ast(code)

        section_heads = {"csect_instruction", "dsect_instruction"}
        pending: list[CodeBlock] = [root]
        while pending:
            node = pending.pop(0)

            # Bucket the children: a new bucket opens at every section
            # instruction; everything else joins the most recent bucket.
            groups: list[list[CodeBlock]] = [[]]
            for child in node.children:
                if child.node_type in section_heads:
                    groups.append([child])
                else:
                    groups[-1].append(child)
            groups = list(filter(None, groups))

            # With more than one bucket, hang each section's statements under
            # a merged node for that section; leading statements stay in place.
            if len(groups) > 1:
                node.children = []
                for group in groups:
                    leader = group[0]
                    if leader.node_type not in section_heads:
                        node.children.extend(group)
                        continue
                    section = self.merge_nodes(group)
                    section.children = group
                    section.node_type = NodeType(str(leader.node_type)[:5])
                    node.children.append(section)

            pending.extend(node.children)

        return root
@@ -1,29 +0,0 @@
1
- from janus.language.alc.alc import AlcSplitter
2
- from janus.language.mumps.mumps import MumpsSplitter
3
- from janus.language.naive.registry import register_splitter
4
- from janus.language.treesitter import TreeSitterSplitter
5
- from janus.utils.enums import LANGUAGES
6
-
7
-
8
@register_splitter("ast-flex")
def get_flexible_ast(language: str, **kwargs):
    """Return the flexible AST splitter for ``language``.

    HLASM and MUMPS use dedicated splitters; any other language falls back to
    the generic tree-sitter splitter. ``kwargs`` are forwarded unchanged.
    """
    if language == "ibmhlasm":
        return AlcSplitter(**kwargs)
    if language == "mumps":
        return MumpsSplitter(**kwargs)
    return TreeSitterSplitter(language=language, **kwargs)
16
-
17
-
18
- @register_splitter("ast-strict")
19
- def get_strict_ast(language: str, **kwargs):
20
- kwargs.update(
21
- protected_node_types=LANGUAGES[language]["functional_node_types"],
22
- prune_unprotected=True,
23
- )
24
- if language == "ibmhlasm":
25
- return AlcSplitter(**kwargs)
26
- elif language == "mumps":
27
- return MumpsSplitter(**kwargs)
28
- else:
29
- return TreeSitterSplitter(language=language, **kwargs)
File without changes
File without changes
File without changes