janus-llm 2.0.2.tar.gz → 2.1.0.tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (98)
  1. {janus_llm-2.0.2 → janus_llm-2.1.0}/PKG-INFO +1 -1
  2. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/__init__.py +1 -1
  3. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/cli.py +1 -2
  4. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/converter.py +3 -0
  5. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/embedding/_tests/test_collections.py +2 -2
  6. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/_tests/test_splitter.py +1 -1
  7. janus_llm-2.1.0/janus/language/alc/__init__.py +1 -0
  8. janus_llm-2.1.0/janus/language/alc/_tests/test_alc.py +28 -0
  9. janus_llm-2.1.0/janus/language/alc/alc.py +87 -0
  10. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/block.py +3 -1
  11. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/mumps/mumps.py +2 -3
  12. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/naive/__init__.py +1 -1
  13. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/naive/basic_splitter.py +4 -4
  14. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/naive/chunk_splitter.py +4 -4
  15. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/naive/registry.py +1 -1
  16. janus_llm-2.1.0/janus/language/naive/simple_ast.py +29 -0
  17. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/naive/tag_splitter.py +4 -4
  18. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/splitter.py +10 -4
  19. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/treesitter/treesitter.py +6 -7
  20. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/llm/model_callbacks.py +1 -1
  21. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/llm/models_info.py +2 -3
  22. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/_tests/test_llm.py +2 -3
  23. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/_tests/test_rouge_score.py +1 -1
  24. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/_tests/test_similarity_score.py +1 -1
  25. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/complexity_metrics.py +3 -4
  26. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/metric.py +3 -4
  27. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/prompts/prompt.py +34 -0
  28. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/translate.py +11 -5
  29. janus_llm-2.1.0/janus/utils/_tests/__init__.py +0 -0
  30. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/utils/enums.py +6 -5
  31. {janus_llm-2.0.2 → janus_llm-2.1.0}/pyproject.toml +1 -1
  32. janus_llm-2.0.2/janus/language/naive/simple_ast.py +0 -18
  33. {janus_llm-2.0.2 → janus_llm-2.1.0}/LICENSE +0 -0
  34. {janus_llm-2.0.2 → janus_llm-2.1.0}/README.md +0 -0
  35. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/__main__.py +0 -0
  36. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/_tests/__init__.py +0 -0
  37. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/_tests/conftest.py +0 -0
  38. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/_tests/test_cli.py +0 -0
  39. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/_tests/test_translate.py +0 -0
  40. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/embedding/__init__.py +0 -0
  41. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/embedding/_tests/__init__.py +0 -0
  42. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/embedding/_tests/test_database.py +0 -0
  43. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/embedding/_tests/test_vectorize.py +0 -0
  44. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/embedding/collections.py +0 -0
  45. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/embedding/database.py +0 -0
  46. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/embedding/embedding_models_info.py +0 -0
  47. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/embedding/vectorize.py +0 -0
  48. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/__init__.py +0 -0
  49. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/_tests/__init__.py +0 -0
  50. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/_tests/test_combine.py +0 -0
  51. {janus_llm-2.0.2/janus/language/binary → janus_llm-2.1.0/janus/language/alc}/_tests/__init__.py +0 -0
  52. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/binary/__init__.py +0 -0
  53. {janus_llm-2.0.2/janus/language/mumps → janus_llm-2.1.0/janus/language/binary}/_tests/__init__.py +0 -0
  54. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/binary/_tests/test_binary.py +0 -0
  55. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/binary/binary.py +0 -0
  56. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/binary/reveng/decompile_script.py +0 -0
  57. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/combine.py +0 -0
  58. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/file.py +0 -0
  59. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/mumps/__init__.py +0 -0
  60. {janus_llm-2.0.2/janus/language/treesitter → janus_llm-2.1.0/janus/language/mumps}/_tests/__init__.py +0 -0
  61. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/mumps/_tests/test_mumps.py +0 -0
  62. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/mumps/patterns.py +0 -0
  63. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/node.py +0 -0
  64. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/treesitter/__init__.py +0 -0
  65. {janus_llm-2.0.2/janus/metrics → janus_llm-2.1.0/janus/language/treesitter}/_tests/__init__.py +0 -0
  66. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/language/treesitter/_tests/test_treesitter.py +0 -0
  67. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/llm/__init__.py +0 -0
  68. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/__init__.py +0 -0
  69. {janus_llm-2.0.2/janus/parsers → janus_llm-2.1.0/janus/metrics/_tests}/__init__.py +0 -0
  70. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/_tests/reference.py +0 -0
  71. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/_tests/target.py +0 -0
  72. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/_tests/test_bleu.py +0 -0
  73. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/_tests/test_chrf.py +0 -0
  74. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/_tests/test_file_pairing.py +0 -0
  75. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/_tests/test_reading.py +0 -0
  76. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/_tests/test_treesitter_metrics.py +0 -0
  77. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/bleu.py +0 -0
  78. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/chrf.py +0 -0
  79. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/cli.py +0 -0
  80. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/file_pairing.py +0 -0
  81. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/llm_metrics.py +0 -0
  82. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/reading.py +0 -0
  83. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/rouge_score.py +0 -0
  84. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/similarity.py +0 -0
  85. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/metrics/splitting.py +0 -0
  86. {janus_llm-2.0.2/janus/parsers/_tests → janus_llm-2.1.0/janus/parsers}/__init__.py +0 -0
  87. {janus_llm-2.0.2/janus/prompts → janus_llm-2.1.0/janus/parsers/_tests}/__init__.py +0 -0
  88. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/parsers/_tests/test_code_parser.py +0 -0
  89. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/parsers/code_parser.py +0 -0
  90. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/parsers/doc_parser.py +0 -0
  91. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/parsers/eval_parser.py +0 -0
  92. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/parsers/reqs_parser.py +0 -0
  93. {janus_llm-2.0.2/janus/utils → janus_llm-2.1.0/janus/prompts}/__init__.py +0 -0
  94. {janus_llm-2.0.2/janus/utils/_tests → janus_llm-2.1.0/janus/utils}/__init__.py +0 -0
  95. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/utils/_tests/test_logger.py +0 -0
  96. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/utils/_tests/test_progress.py +0 -0
  97. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/utils/logger.py +0 -0
  98. {janus_llm-2.0.2 → janus_llm-2.1.0}/janus/utils/progress.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: janus-llm
3
- Version: 2.0.2
3
+ Version: 2.1.0
4
4
  Summary: A transcoding library using LLMs.
5
5
  Home-page: https://github.com/janus-llm/janus-llm
6
6
  License: Apache 2.0
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
5
5
  from .metrics import * # noqa: F403
6
6
  from .translate import Translator
7
7
 
8
- __version__ = "2.0.2"
8
+ __version__ = "2.1.0"
9
9
 
10
10
  # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
11
11
  warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
@@ -12,8 +12,6 @@ from rich.console import Console
12
12
  from rich.prompt import Confirm
13
13
  from typing_extensions import Annotated
14
14
 
15
- from janus.language.naive.registry import CUSTOM_SPLITTERS
16
-
17
15
  from .embedding.collections import Collections
18
16
  from .embedding.database import ChromaEmbeddingDatabase
19
17
  from .embedding.embedding_models_info import (
@@ -25,6 +23,7 @@ from .embedding.embedding_models_info import (
25
23
  from .embedding.vectorize import ChromaDBVectorizer
26
24
  from .language.binary import BinarySplitter
27
25
  from .language.mumps import MumpsSplitter
26
+ from .language.naive.registry import CUSTOM_SPLITTERS
28
27
  from .language.treesitter import TreeSitterSplitter
29
28
  from .llm.model_callbacks import COST_PER_1K_TOKENS
30
29
  from .llm.models_info import MODEL_CONFIG_DIR, MODEL_TYPE_CONSTRUCTORS, TOKEN_LIMITS
@@ -3,6 +3,7 @@ from typing import Any
3
3
 
4
4
  from langchain.schema.language_model import BaseLanguageModel
5
5
 
6
+ from .language.alc.alc import AlcSplitter
6
7
  from .language.binary import BinarySplitter
7
8
  from .language.mumps import MumpsSplitter
8
9
  from .language.splitter import Splitter
@@ -152,6 +153,8 @@ class Converter:
152
153
  if self._source_language in CUSTOM_SPLITTERS:
153
154
  if self._source_language == "mumps":
154
155
  self._splitter = MumpsSplitter(**kwargs)
156
+ elif self._source_language == "ibmhlasm":
157
+ self._splitter = AlcSplitter(**kwargs)
155
158
  elif self._source_language == "binary":
156
159
  self._splitter = BinarySplitter(**kwargs)
157
160
  else:
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
4
4
 
5
5
  import pytest
6
6
 
7
- from janus.embedding.collections import Collections
8
- from janus.utils.enums import EmbeddingType
7
+ from ...utils.enums import EmbeddingType
8
+ from ..collections import Collections
9
9
 
10
10
 
11
11
  class TestCollections(unittest.TestCase):
@@ -1,6 +1,6 @@
1
1
  import unittest
2
2
 
3
- from janus.language.splitter import Splitter
3
+ from ..splitter import Splitter
4
4
 
5
5
 
6
6
  class TestSplitter(unittest.TestCase):
@@ -0,0 +1 @@
1
+ from .alc import AlcCombiner, AlcSplitter
@@ -0,0 +1,28 @@
1
+ import unittest
2
+ from pathlib import Path
3
+
4
+ from ....llm import load_model
5
+ from ...combine import Combiner
6
+ from ..alc import AlcSplitter
7
+
8
+
9
+ class TestAlcSplitter(unittest.TestCase):
10
+ """Tests for the Splitter class."""
11
+
12
+ def setUp(self):
13
+ """Set up the tests."""
14
+ model_name = "gpt-3.5-turbo-0125"
15
+ llm, _, _ = load_model(model_name)
16
+ self.splitter = AlcSplitter(model=llm)
17
+ self.combiner = Combiner(language="ibmhlasm")
18
+ self.test_file = Path("janus/language/alc/_tests/alc.asm")
19
+
20
+ def test_split(self):
21
+ """Test the split method."""
22
+ tree_root = self.splitter.split(self.test_file)
23
+ self.assertEqual(tree_root.n_descendents, 34)
24
+ self.assertLessEqual(tree_root.max_tokens, self.splitter.max_tokens)
25
+ self.assertFalse(tree_root.complete)
26
+ self.combiner.combine_children(tree_root)
27
+ self.assertTrue(tree_root.complete)
28
+ self.assertEqual(tree_root.complete_text, self.test_file.read_text())
@@ -0,0 +1,87 @@
1
+ from langchain.schema.language_model import BaseLanguageModel
2
+
3
+ from ...utils.logger import create_logger
4
+ from ..block import CodeBlock
5
+ from ..combine import Combiner
6
+ from ..node import NodeType
7
+ from ..treesitter import TreeSitterSplitter
8
+
9
+ log = create_logger(__name__)
10
+
11
+
12
+ class AlcCombiner(Combiner):
13
+ """A class that combines code blocks into ALC files."""
14
+
15
+ def __init__(self) -> None:
16
+ """Initialize a AlcCombiner instance."""
17
+ super().__init__("ibmhlasm")
18
+
19
+
20
+ class AlcSplitter(TreeSitterSplitter):
21
+ """A class for splitting ALC code into functional blocks to prompt
22
+ with for transcoding.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ model: None | BaseLanguageModel = None,
28
+ max_tokens: int = 4096,
29
+ protected_node_types: tuple[str, ...] = (),
30
+ prune_node_types: tuple[str, ...] = (),
31
+ prune_unprotected: bool = False,
32
+ ):
33
+ """Initialize a AlcSplitter instance.
34
+
35
+ Arguments:
36
+ max_tokens: The maximum number of tokens supported by the model
37
+ """
38
+ super().__init__(
39
+ language="ibmhlasm",
40
+ model=model,
41
+ max_tokens=max_tokens,
42
+ protected_node_types=protected_node_types,
43
+ prune_node_types=prune_node_types,
44
+ prune_unprotected=prune_unprotected,
45
+ )
46
+
47
+ def _get_ast(self, code: str) -> CodeBlock:
48
+ root = super()._get_ast(code)
49
+
50
+ # Current treesitter implementation does not nest csects and dsects
51
+ # The loop below nests nodes following csect/dsect instructions into
52
+ # the children of that instruction
53
+ sect_types = {"csect_instruction", "dsect_instruction"}
54
+ queue: list[CodeBlock] = [root]
55
+ while queue:
56
+ block = queue.pop(0)
57
+
58
+ # Search this children for csects and dsects. Create a list of groups
59
+ # where each group is a csect or dsect, starting with the csect/dsect
60
+ # instruction and containing all the subsequent nodes up until the
61
+ # next csect or dsect instruction
62
+ sects: list[list[CodeBlock]] = [[]]
63
+ for c in block.children:
64
+ if c.node_type in sect_types:
65
+ sects.append([c])
66
+ else:
67
+ sects[-1].append(c)
68
+
69
+ sects = [s for s in sects if s]
70
+
71
+ # Restructure the tree, making the head of each group the parent
72
+ # of all the remaining nodes in that group
73
+ if len(sects) > 1:
74
+ block.children = []
75
+ for sect in sects:
76
+ if sect[0].node_type in sect_types:
77
+ sect_node = self.merge_nodes(sect)
78
+ sect_node.children = sect
79
+ sect_node.node_type = NodeType(str(sect[0].node_type)[:5])
80
+ block.children.append(sect_node)
81
+ else:
82
+ block.children.extend(sect)
83
+
84
+ # Push the children onto the queue
85
+ queue.extend(block.children)
86
+
87
+ return root
@@ -152,9 +152,11 @@ class CodeBlock:
152
152
  Returns:
153
153
  A string representation of the tree with this block as the root
154
154
  """
155
+ tokens = self.tokens
155
156
  identifier = self.id
156
157
  if self.text is None:
157
158
  identifier = f"({identifier})"
159
+ tokens = self.total_tokens
158
160
  elif not self.complete:
159
161
  identifier += "*"
160
162
  if self.start_point is not None and self.end_point is not None:
@@ -165,7 +167,7 @@ class CodeBlock:
165
167
  seg = ""
166
168
  return "\n".join(
167
169
  [
168
- f"{'| '*depth}{identifier}{seg}",
170
+ f"{'| '*depth}{identifier}{seg} ({tokens:,d} tokens)",
169
171
  *[c.tree_str(depth + 1) for c in self.children],
170
172
  ]
171
173
  )
@@ -48,6 +48,7 @@ class MumpsSplitter(Splitter):
48
48
  max_tokens: int = 4096,
49
49
  protected_node_types: tuple[str] = ("routine_definition",),
50
50
  prune_node_types: tuple[str] = (),
51
+ prune_unprotected: bool = False,
51
52
  ):
52
53
  """Initialize a MumpsSplitter instance.
53
54
 
@@ -60,11 +61,9 @@ class MumpsSplitter(Splitter):
60
61
  max_tokens=max_tokens,
61
62
  protected_node_types=protected_node_types,
62
63
  prune_node_types=prune_node_types,
64
+ prune_unprotected=prune_unprotected,
63
65
  )
64
66
 
65
- # MUMPS code tends to take about 2/3 the space of Python
66
- self.max_tokens: int = int(max_tokens * 2 / 5)
67
-
68
67
  def _set_identifiers(self, root: CodeBlock, name: str):
69
68
  stack = [root]
70
69
  while stack:
@@ -1,4 +1,4 @@
1
1
  from .basic_splitter import FileSplitter
2
2
  from .chunk_splitter import ChunkSplitter
3
- from .simple_ast import FlexibleTreeSitterSplitter, StrictTreeSitterSplitter
3
+ from .simple_ast import get_flexible_ast, get_strict_ast
4
4
  from .tag_splitter import TagSplitter
@@ -1,7 +1,7 @@
1
- from janus.language.block import CodeBlock
2
- from janus.language.naive.chunk_splitter import ChunkSplitter
3
- from janus.language.naive.registry import register_splitter
4
- from janus.language.splitter import FileSizeError
1
+ from ..block import CodeBlock
2
+ from ..naive.chunk_splitter import ChunkSplitter
3
+ from ..naive.registry import register_splitter
4
+ from ..splitter import FileSizeError
5
5
 
6
6
 
7
7
  @register_splitter("file")
@@ -1,7 +1,7 @@
1
- from janus.language.block import CodeBlock
2
- from janus.language.naive.registry import register_splitter
3
- from janus.language.node import NodeType
4
- from janus.language.splitter import Splitter
1
+ from ..block import CodeBlock
2
+ from ..node import NodeType
3
+ from ..splitter import Splitter
4
+ from .registry import register_splitter
5
5
 
6
6
 
7
7
  @register_splitter("chunk")
@@ -1,6 +1,6 @@
1
1
  from typing import Callable, Dict
2
2
 
3
- from janus.language.splitter import Splitter
3
+ from ..splitter import Splitter
4
4
 
5
5
  CUSTOM_SPLITTERS: Dict[str, Callable[..., Splitter]] = dict()
6
6
 
@@ -0,0 +1,29 @@
1
+ from ...utils.enums import LANGUAGES
2
+ from ..alc.alc import AlcSplitter
3
+ from ..mumps.mumps import MumpsSplitter
4
+ from ..treesitter import TreeSitterSplitter
5
+ from .registry import register_splitter
6
+
7
+
8
+ @register_splitter("ast-flex")
9
+ def get_flexible_ast(language: str, **kwargs):
10
+ if language == "ibmhlasm":
11
+ return AlcSplitter(**kwargs)
12
+ elif language == "mumps":
13
+ return MumpsSplitter(**kwargs)
14
+ else:
15
+ return TreeSitterSplitter(language=language, **kwargs)
16
+
17
+
18
+ @register_splitter("ast-strict")
19
+ def get_strict_ast(language: str, **kwargs):
20
+ kwargs.update(
21
+ protected_node_types=LANGUAGES[language]["functional_node_types"],
22
+ prune_unprotected=True,
23
+ )
24
+ if language == "ibmhlasm":
25
+ return AlcSplitter(**kwargs)
26
+ elif language == "mumps":
27
+ return MumpsSplitter(**kwargs)
28
+ else:
29
+ return TreeSitterSplitter(language=language, **kwargs)
@@ -1,7 +1,7 @@
1
- from janus.language.block import CodeBlock
2
- from janus.language.naive.registry import register_splitter
3
- from janus.language.node import NodeType
4
- from janus.language.splitter import Splitter
1
+ from ..block import CodeBlock
2
+ from ..node import NodeType
3
+ from ..splitter import Splitter
4
+ from .registry import register_splitter
5
5
 
6
6
 
7
7
  @register_splitter("tag")
@@ -47,8 +47,8 @@ class Splitter(FileManager):
47
47
  model: None | BaseLanguageModel = None,
48
48
  max_tokens: int = 4096,
49
49
  skip_merge: bool = False,
50
- protected_node_types: tuple[str] = (),
51
- prune_node_types: tuple[str] = (),
50
+ protected_node_types: tuple[str, ...] = (),
51
+ prune_node_types: tuple[str, ...] = (),
52
52
  prune_unprotected: bool = False,
53
53
  ):
54
54
  """
@@ -340,7 +340,10 @@ class Splitter(FileManager):
340
340
  # Double check length (in theory this should never be an issue)
341
341
  tokens = self._count_tokens(text)
342
342
  if tokens > self.max_tokens:
343
- log.error(f"Merged node ({name}) too long for context!")
343
+ log.error(
344
+ f"Merged node ({name}) too long for context!"
345
+ f" ({tokens} > {self.max_tokens})"
346
+ )
344
347
 
345
348
  return CodeBlock(
346
349
  text=text,
@@ -420,7 +423,10 @@ class Splitter(FileManager):
420
423
  name = f"{node.name}-L#{node_line}"
421
424
  tokens = self._count_tokens(line)
422
425
  if tokens > self.max_tokens:
423
- raise TokenLimitError(r"Irreducible node too large for context!")
426
+ raise TokenLimitError(
427
+ "Irreducible node too large for context!"
428
+ f" ({tokens} > {self.max_tokens})"
429
+ )
424
430
 
425
431
  node.children.append(
426
432
  CodeBlock(
@@ -26,8 +26,8 @@ class TreeSitterSplitter(Splitter):
26
26
  language: str,
27
27
  model: None | BaseLanguageModel = None,
28
28
  max_tokens: int = 4096,
29
- protected_node_types: tuple[str] = (),
30
- prune_node_types: tuple[str] = (),
29
+ protected_node_types: tuple[str, ...] = (),
30
+ prune_node_types: tuple[str, ...] = (),
31
31
  prune_unprotected: bool = False,
32
32
  ) -> None:
33
33
  """Initialize a TreeSitterSplitter instance.
@@ -48,10 +48,10 @@ class TreeSitterSplitter(Splitter):
48
48
  self._load_parser()
49
49
 
50
50
  def _get_ast(self, code: str) -> CodeBlock:
51
- code = bytes(code, "utf-8")
52
- tree = self.parser.parse(code)
51
+ code_bytes = bytes(code, "utf-8")
52
+ tree = self.parser.parse(code_bytes)
53
53
  root = tree.walk().node
54
- root = self._node_to_block(root, code)
54
+ root = self._node_to_block(root, code_bytes)
55
55
  return root
56
56
 
57
57
  # Recursively print tree to view parsed output (dev helper function)
@@ -98,7 +98,7 @@ class TreeSitterSplitter(Splitter):
98
98
 
99
99
  text = node.text.decode()
100
100
  children = [self._node_to_block(child, original_text) for child in node.children]
101
- node = CodeBlock(
101
+ return CodeBlock(
102
102
  id=node.id,
103
103
  name=str(node.id),
104
104
  text=text,
@@ -112,7 +112,6 @@ class TreeSitterSplitter(Splitter):
112
112
  language=self.language,
113
113
  tokens=self._count_tokens(text),
114
114
  )
115
- return node
116
115
 
117
116
  def _load_parser(self) -> None:
118
117
  """Load the parser for the given language.
@@ -8,7 +8,7 @@ from langchain_core.messages import AIMessage
8
8
  from langchain_core.outputs import ChatGeneration, LLMResult
9
9
  from langchain_core.tracers.context import register_configure_hook
10
10
 
11
- from janus.utils.logger import create_logger
11
+ from ..utils.logger import create_logger
12
12
 
13
13
  log = create_logger(__name__)
14
14
 
@@ -8,8 +8,7 @@ from langchain_community.llms import HuggingFaceTextGenInference
8
8
  from langchain_core.language_models import BaseLanguageModel
9
9
  from langchain_openai import ChatOpenAI
10
10
 
11
- from janus.llm.model_callbacks import COST_PER_1K_TOKENS
12
- from janus.prompts.prompt import (
11
+ from ..prompts.prompt import (
13
12
  ChatGptPromptEngine,
14
13
  ClaudePromptEngine,
15
14
  CoherePromptEngine,
@@ -18,8 +17,8 @@ from janus.prompts.prompt import (
18
17
  PromptEngine,
19
18
  TitanPromptEngine,
20
19
  )
21
-
22
20
  from ..utils.logger import create_logger
21
+ from .model_callbacks import COST_PER_1K_TOKENS
23
22
 
24
23
  log = create_logger(__name__)
25
24
 
@@ -3,8 +3,7 @@ from unittest.mock import patch
3
3
 
4
4
  import pytest
5
5
 
6
- from janus.llm.models_info import load_model
7
-
6
+ from ...llm.models_info import load_model
8
7
  from ..llm_metrics import llm_evaluate_option, llm_evaluate_ref_option
9
8
 
10
9
 
@@ -40,7 +39,7 @@ class TestLLMMetrics(unittest.TestCase):
40
39
  print("'Hello, world!")
41
40
  """
42
41
 
43
- @patch("janus.llm.models_info.load_model")
42
+ @patch(".llm.models_info.load_model")
44
43
  @patch("janus.metrics.llm_metrics.llm_evaluate")
45
44
  @pytest.mark.llm_eval
46
45
  def test_llm_self_eval_quality(self, mock_llm_evaluate, mock_load_model):
@@ -1,6 +1,6 @@
1
1
  import unittest
2
2
 
3
- from janus.metrics.rouge_score import rouge
3
+ from ..rouge_score import rouge
4
4
 
5
5
 
6
6
  class TestRouge(unittest.TestCase):
@@ -1,6 +1,6 @@
1
1
  import unittest
2
2
 
3
- from janus.metrics.similarity import similarity_score
3
+ from ..similarity import similarity_score
4
4
 
5
5
 
6
6
  class TestSimilarityScore(unittest.TestCase):
@@ -1,10 +1,9 @@
1
1
  import math
2
2
  from typing import List, Optional
3
3
 
4
- from janus.language.block import CodeBlock
5
- from janus.language.treesitter.treesitter import TreeSitterSplitter
6
- from janus.utils.enums import LANGUAGES
7
-
4
+ from ..language.block import CodeBlock
5
+ from ..language.treesitter.treesitter import TreeSitterSplitter
6
+ from ..utils.enums import LANGUAGES
8
7
  from .metric import metric
9
8
 
10
9
 
@@ -7,10 +7,9 @@ import click
7
7
  import typer
8
8
  from typing_extensions import Annotated
9
9
 
10
- from janus.llm import load_model
11
- from janus.utils.enums import LANGUAGES
12
- from janus.utils.logger import create_logger
13
-
10
+ from ..llm import load_model
11
+ from ..utils.enums import LANGUAGES
12
+ from ..utils.logger import create_logger
14
13
  from ..utils.progress import track
15
14
  from .cli import evaluate
16
15
  from .file_pairing import FILE_PAIRING_METHODS
@@ -34,6 +34,40 @@ HUMAN_PROMPT_TEMPLATE_FILENAME = "human.txt"
34
34
  PROMPT_VARIABLES_FILENAME = "variables.json"
35
35
 
36
36
 
37
+ retry_with_output_prompt_text = """Instructions:
38
+ --------------
39
+ {instructions}
40
+ --------------
41
+ Completion:
42
+ --------------
43
+ {input}
44
+ --------------
45
+
46
+ Above, the Completion did not satisfy the constraints given in the Instructions.
47
+ Error:
48
+ --------------
49
+ {error}
50
+ --------------
51
+
52
+ Please try again. Please only respond with an answer that satisfies the
53
+ constraints laid out in the Instructions:"""
54
+
55
+
56
+ retry_with_error_and_output_prompt_text = """Prompt:
57
+ {prompt}
58
+ Completion:
59
+ {input}
60
+
61
+ Above, the Completion did not satisfy the constraints given in the Prompt.
62
+ Details: {error}
63
+ Please try again:"""
64
+
65
+ retry_with_output_prompt = PromptTemplate.from_template(retry_with_output_prompt_text)
66
+ retry_with_error_and_output_prompt = PromptTemplate.from_template(
67
+ retry_with_error_and_output_prompt_text
68
+ )
69
+
70
+
37
71
  class PromptEngine(ABC):
38
72
  """A class defining prompting schemes for the LLM."""
39
73
 
@@ -16,12 +16,11 @@ from langchain_core.runnables import RunnableLambda, RunnableParallel
16
16
  from openai import BadRequestError, RateLimitError
17
17
  from text_generation.errors import ValidationError
18
18
 
19
- from janus.language.naive.registry import CUSTOM_SPLITTERS
20
-
21
19
  from .converter import Converter, run_if_changed
22
20
  from .embedding.vectorize import ChromaDBVectorizer
23
21
  from .language.block import CodeBlock, TranslatedCodeBlock
24
22
  from .language.combine import ChunkCombiner, Combiner, JsonCombiner
23
+ from .language.naive.registry import CUSTOM_SPLITTERS
25
24
  from .language.splitter import EmptyTreeError, FileSizeError, TokenLimitError
26
25
  from .llm import load_model
27
26
  from .llm.model_callbacks import get_model_callback
@@ -30,7 +29,12 @@ from .parsers.code_parser import CodeParser, GenericParser
30
29
  from .parsers.doc_parser import MadlibsDocumentationParser, MultiDocumentationParser
31
30
  from .parsers.eval_parser import EvaluationParser
32
31
  from .parsers.reqs_parser import RequirementsParser
33
- from .prompts.prompt import SAME_OUTPUT, TEXT_OUTPUT
32
+ from .prompts.prompt import (
33
+ SAME_OUTPUT,
34
+ TEXT_OUTPUT,
35
+ retry_with_error_and_output_prompt,
36
+ retry_with_output_prompt,
37
+ )
34
38
  from .utils.enums import LANGUAGES
35
39
  from .utils.logger import create_logger
36
40
 
@@ -407,10 +411,10 @@ class Translator(Converter):
407
411
  """
408
412
  self._parser.set_reference(block.original)
409
413
 
410
- # Retries with just the output and the error
414
+ # Retries with just the format instructions, the output, and the error
411
415
  n1 = round(self.max_prompts ** (1 / 3))
412
416
 
413
- # Retries with the input, output, and error
417
+ # Retries with the input, the output, and the error
414
418
  n2 = round((self.max_prompts // n1) ** (1 / 2))
415
419
 
416
420
  # Retries with just the input
@@ -420,11 +424,13 @@ class Translator(Converter):
420
424
  llm=self._llm,
421
425
  parser=self._parser,
422
426
  max_retries=n1,
427
+ prompt=retry_with_output_prompt,
423
428
  )
424
429
  retry = RetryWithErrorOutputParser.from_llm(
425
430
  llm=self._llm,
426
431
  parser=fix_format,
427
432
  max_retries=n2,
433
+ prompt=retry_with_error_and_output_prompt,
428
434
  )
429
435
 
430
436
  completion_chain = self._prompt | self._llm
File without changes
@@ -10,7 +10,7 @@ class EmbeddingType(Enum):
10
10
  TARGET = 5 # placeholder embeddings, are these useful for analysis?
11
11
 
12
12
 
13
- CUSTOM_SPLITTERS: Set[str] = {"mumps", "binary"}
13
+ CUSTOM_SPLITTERS: Set[str] = {"mumps", "binary", "ibmhlasm"}
14
14
 
15
15
  LANGUAGES: Dict[str, Dict[str, Any]] = {
16
16
  "ada": {
@@ -63,7 +63,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
63
63
  '#include <stdio.h>\n\nint main() {\n printf("Hello, World!\\n");\n'
64
64
  " return 0;\n}\n"
65
65
  ),
66
- "functional_node_type": "function_definition",
66
+ "functional_node_types": ["function_definition"],
67
67
  "comment_node_type": "comment",
68
68
  },
69
69
  "capnp": {
@@ -206,7 +206,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
206
206
  "example": (
207
207
  "program HelloWorld\n print *, 'Hello, World!'\nend program HelloWorld\n"
208
208
  ),
209
- "functional_node_type": "function",
209
+ "functional_node_types": ["function"],
210
210
  "comment_node_type": "comment",
211
211
  },
212
212
  "gitattributes": {
@@ -300,6 +300,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
300
300
  END HELLO
301
301
  """
302
302
  ),
303
+ "functional_node_types": ["csect", "dsect"],
303
304
  "branch_node_types": ["branch_instruction"],
304
305
  "operation_node_types": ["operation", "branch_operation"],
305
306
  "operand_node_types": ["operands"],
@@ -420,7 +421,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
420
421
  "suffix": "m",
421
422
  "url": "https://github.com/janus-llm/tree-sitter-mumps",
422
423
  "example": 'WRITE "Hello, World!"',
423
- "functional_node_type": "routine_definition",
424
+ "functional_node_types": ["routine_definition"],
424
425
  "comment_node_type": "comment",
425
426
  "branch_node_types": ["if_statement"],
426
427
  "operation_node_types": [
@@ -512,7 +513,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
512
513
  "suffix": "py",
513
514
  "url": "https://github.com/tree-sitter/tree-sitter-python",
514
515
  "example": "# Hello, World!\nprint('Hello, World!')\n",
515
- "functional_node_type": "function_definition",
516
+ "functional_node_types": ["function_definition"],
516
517
  "comment_node_type": "comment",
517
518
  },
518
519
  "qmljs": {
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "janus-llm"
3
- version = "2.0.2"
3
+ version = "2.1.0"
4
4
  description = "A transcoding library using LLMs."
5
5
  authors = ["Michael Doyle <mdoyle@mitre.org>", "Chris Glasz <cglasz@mitre.org>",
6
6
  "Chris Tohline <ctohline@mitre.org>", "William Macke <wmacke@mitre.org>",
@@ -1,18 +0,0 @@
1
- from janus.language.naive.registry import register_splitter
2
- from janus.language.treesitter import TreeSitterSplitter
3
- from janus.utils.enums import LANGUAGES
4
-
5
-
6
- @register_splitter("ast-flex")
7
- class FlexibleTreeSitterSplitter(TreeSitterSplitter):
8
- pass
9
-
10
-
11
- @register_splitter("ast-strict")
12
- class StrictTreeSitterSplitter(TreeSitterSplitter):
13
- def __init__(self, language: str, **kwargs):
14
- kwargs.update(
15
- protected_node_types=(LANGUAGES[language]["functional_node_type"],),
16
- prune_unprotected=True,
17
- )
18
- super().__init__(language=language, **kwargs)
File without changes
File without changes
File without changes