janus-llm 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (74)
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +120 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +9 -6
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +134 -70
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/METADATA +23 -10
  70. janus_llm-2.0.0.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,62 @@
1
+ import unittest
2
+
3
+ from ..combine import CodeBlock, Combiner, TranslatedCodeBlock
4
+
5
+
6
class TestCombiner(unittest.TestCase):
    """Tests for Combiner on a small root block with a single leaf child."""

    def setUp(self):
        """Build a root CodeBlock with one leaf child, plus a translated copy."""
        self.combiner = Combiner()
        leaf = CodeBlock(
            id=2,
            name="child",
            node_type="test",
            language="python",
            text="test",
            start_point=(0, 0),
            end_point=(0, 0),
            start_byte=0,
            end_byte=0,
            tokens=[],
            children=[],
        )
        self.block = CodeBlock(
            id=1,
            name="test",
            node_type="test",
            language="python",
            text="# test",
            start_point=(0, 0),
            end_point=(0, 0),
            start_byte=0,
            end_byte=0,
            tokens=[],
            children=[leaf],
        )
        self.translated_block = TranslatedCodeBlock(self.block, language="python")

    def test_combine(self):
        """combine() is expected to leave omit_prefix False on the root."""
        self.combiner.combine(self.block)
        self.assertFalse(self.block.omit_prefix)

    def test_combine_children(self):
        """An untranslated block flagged incomplete should end up complete."""
        root = self.block
        root.complete = False
        self.combiner.combine_children(root)
        self.assertTrue(root.complete)

    def test_combine_children_with_translated_block(self):
        """A translated block whose child is untranslated should stay incomplete."""
        translated = self.translated_block
        translated.complete = False
        self.combiner.combine_children(translated)
        self.assertFalse(translated.complete)

    def test_combine_children_with_text_none(self):
        """combine_children() should keep the root's text and mark it complete.

        Despite the name, text is not set to None here; it keeps the value
        assigned in setUp.
        """
        root = self.block
        self.combiner.combine_children(root)
        self.assertEqual(root.text, "# test")
        self.assertTrue(root.complete)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,16 @@
1
+ import unittest
2
+
3
+ from janus.language.splitter import Splitter
4
+
5
+
6
class TestSplitter(unittest.TestCase):
    """Checks that the base Splitter refuses to split on its own."""

    def setUp(self):
        self.splitter = Splitter(language="python")

    def test_split(self):
        """split() on the base class is expected to raise NotImplementedError."""
        path = "janus/__main__.py"
        with self.assertRaises(NotImplementedError):
            self.splitter.split(path)


if __name__ == "__main__":
    unittest.main()
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import unittest
3
3
  from pathlib import Path
4
+ from unittest.mock import patch
4
5
 
5
6
  import pytest
6
7
 
@@ -12,12 +13,26 @@ class TestBinarySplitter(unittest.TestCase):
12
13
  """Tests for the BinarySplitter class."""
13
14
 
14
15
  def setUp(self):
15
- model_name = "gpt-3.5-turbo"
16
+ model_name = "gpt-3.5-turbo-0125"
16
17
  self.binary_file = Path("janus/language/binary/_tests/hello")
17
18
  self.llm, _, _ = load_model(model_name)
18
19
  self.splitter = BinarySplitter(model=self.llm)
19
20
  os.environ["GHIDRA_INSTALL_PATH"] = "~/programs/ghidra_10.4_PUBLIC"
20
21
 
22
def test_setup(self):
    """Test that setUp exports GHIDRA_INSTALL_PATH to the environment.

    The value is read from the real process environment. An earlier version
    patched ``os.getenv`` and then asserted on the patched function, so it
    passed no matter what setUp did.
    """
    expected = "~/programs/ghidra_10.4_PUBLIC"
    self.assertEqual(os.environ.get("GHIDRA_INSTALL_PATH"), expected)
    # os.getenv reads the same mapping, so it must agree.
    self.assertEqual(os.getenv("GHIDRA_INSTALL_PATH"), expected)
30
+
31
def test_initialization(self):
    """The splitter built in setUp should be a BinarySplitter holding the loaded model."""
    splitter = self.splitter
    self.assertIsInstance(splitter, BinarySplitter)
    self.assertEqual(splitter.model, self.llm)
35
+
21
36
  @pytest.mark.ghidra(
22
37
  reason=(
23
38
  "No way to test this in CI w/o installing Ghidra, but want to keep here to "
@@ -29,7 +29,13 @@ class BinarySplitter(TreeSitterSplitter):
29
29
  with for transcoding.
30
30
  """
31
31
 
32
- def __init__(self, model: None | BaseLanguageModel = None, max_tokens: int = 4096):
32
+ def __init__(
33
+ self,
34
+ model: None | BaseLanguageModel = None,
35
+ max_tokens: int = 4096,
36
+ protected_node_types: tuple[str] = (),
37
+ prune_node_types: tuple[str] = (),
38
+ ):
33
39
  """Initialize a BinarySplitter instance.
34
40
 
35
41
  Arguments:
@@ -40,7 +46,8 @@ class BinarySplitter(TreeSitterSplitter):
40
46
  language="binary",
41
47
  model=model,
42
48
  max_tokens=max_tokens,
43
- use_placeholders=False,
49
+ protected_node_types=protected_node_types,
50
+ prune_node_types=prune_node_types,
44
51
  )
45
52
 
46
53
  def _execute_ghidra_script(self, cmd: list[str]) -> str:
@@ -131,7 +138,7 @@ class BinarySplitter(TreeSitterSplitter):
131
138
  code = self._get_decompilation(file)
132
139
 
133
140
  root = self._get_ast(code)
134
- self._set_identifiers(root, path)
141
+ self._set_identifiers(root, path.name)
135
142
  self._segment_leaves(root)
136
143
  self._merge_tree(root)
137
144
 
janus/language/block.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from functools import total_ordering
2
- from typing import ForwardRef, Hashable
2
+ from typing import ForwardRef, Hashable, Optional, Tuple
3
3
 
4
4
  from ..utils.logger import create_logger
5
5
  from .node import NodeType
@@ -14,7 +14,7 @@ class CodeBlock:
14
14
  Attributes:
15
15
  id: The id of the code block in the AST
16
16
  name: Descriptive name of node
17
- type: The type of the code block ('function', 'module', etc.). Defined in the
17
+ node_type: The type of the code block ('function', 'module', etc.). Defined in the
18
18
  language-specific modules.
19
19
  language: The language of the code block.
20
20
  text: The code block.
@@ -33,32 +33,32 @@ class CodeBlock:
33
33
  def __init__(
34
34
  self,
35
35
  id: Hashable,
36
- name: None | str,
37
- type: NodeType,
36
+ name: Optional[str],
37
+ node_type: NodeType,
38
38
  language: str,
39
- text: None | str,
40
- start_point: None | tuple[int, int],
41
- end_point: None | tuple[int, int],
42
- start_byte: None | int,
43
- end_byte: None | int,
39
+ text: Optional[str],
40
+ start_point: Optional[Tuple[int, int]],
41
+ end_point: Optional[Tuple[int, int]],
42
+ start_byte: Optional[int],
43
+ end_byte: Optional[int],
44
44
  tokens: int,
45
45
  children: list[ForwardRef("CodeBlock")],
46
- embedding_id: None | str = None,
47
- affixes: tuple[str, str] = ("", ""),
46
+ embedding_id: Optional[str] = None,
47
+ affixes: Tuple[str, str] = ("", ""),
48
48
  ) -> None:
49
49
  self.id: Hashable = id
50
- self.name: None | str = name
51
- self.type: NodeType = type
50
+ self.name: Optional[str] = name
51
+ self.node_type: NodeType = node_type
52
52
  self.language: str = language
53
- self.text: None | str = text
54
- self.start_point: None | tuple[int, int] = start_point
55
- self.end_point: None | tuple[int, int] = end_point
56
- self.start_byte: None | [int] = start_byte
57
- self.end_byte: None | [int] = end_byte
53
+ self.text: Optional[str] = text
54
+ self.start_point: Optional[Tuple[int, int]] = start_point
55
+ self.end_point: Optional[Tuple[int, int]] = end_point
56
+ self.start_byte: Optional[int] = start_byte
57
+ self.end_byte: Optional[int] = end_byte
58
58
  self.tokens: int = tokens
59
59
  self.children: list[ForwardRef("CodeBlock")] = sorted(children)
60
- self.embedding_id: None | [str] = embedding_id
61
- self.affixes: tuple[str, str] = affixes
60
+ self.embedding_id: Optional[str] = embedding_id
61
+ self.affixes: Tuple[str, str] = affixes
62
62
 
63
63
  self.complete = True
64
64
  self.omit_prefix = True
@@ -83,15 +83,7 @@ class CodeBlock:
83
83
 
84
84
  @property
85
85
  def complete_text(self) -> str:
86
- return f"{self.prefix}{self.text}{self.suffix}"
87
-
88
- @property
89
- def placeholder(self) -> str:
90
- return f"<<<{self.id}>>>"
91
-
92
- @property
93
- def complete_placeholder(self) -> str:
94
- return f"{self.prefix}<<<{self.id}>>>{self.suffix}"
86
+ return f"{self.prefix}{self.text or ''}{self.suffix}"
95
87
 
96
88
  @property
97
89
  def n_descendents(self) -> int:
@@ -146,6 +138,14 @@ class CodeBlock:
146
138
  self.affixes = (self.affixes[0], "")
147
139
  return suffix
148
140
 
141
def rebuild_text_from_children(self):
    """Recompute this block's text, affixes and token count from its children.

    The first child's prefix (via ``pop_prefix()``) and the last child's
    suffix (via ``pop_suffix()``) are hoisted onto this block's own affixes.
    They go inside this block's existing prefix and suffix. The popping
    happens before the children's text is joined. A block with no children
    is left untouched.
    """
    if not self.children:
        return
    first_child, last_child = self.children[0], self.children[-1]
    outer_prefix, outer_suffix = self.affixes
    # Pop before joining so the hoisted affixes do not appear twice.
    hoisted_prefix = outer_prefix + first_child.pop_prefix()
    hoisted_suffix = last_child.pop_suffix() + outer_suffix
    pieces = [child.complete_text for child in self.children]
    self.text = "".join(pieces)
    self.affixes = (hoisted_prefix, hoisted_suffix)
    self.tokens = sum(child.tokens for child in self.children)
148
+
149
149
  def tree_str(self, depth: int = 0) -> str:
150
150
  """A string representation of the tree with this block as the root
151
151
 
@@ -195,7 +195,7 @@ class TranslatedCodeBlock(CodeBlock):
195
195
  super().__init__(
196
196
  id=original.id,
197
197
  name=original.name,
198
- type=original.type,
198
+ node_type=original.node_type,
199
199
  language=language,
200
200
  text=None,
201
201
  start_point=original.start_point,
@@ -214,6 +214,7 @@ class TranslatedCodeBlock(CodeBlock):
214
214
  self.translated = False
215
215
  self.cost = 0.0
216
216
  self.retries = 0
217
+ self.processing_time = 0
217
218
 
218
219
  @property
219
220
  def total_cost(self) -> float:
janus/language/combine.py CHANGED
@@ -11,14 +11,14 @@ class Combiner(FileManager):
11
11
  """
12
12
 
13
13
  @staticmethod
14
- def combine(block: CodeBlock) -> None:
14
+ def combine(root: CodeBlock) -> None:
15
15
  """Combine the given block with its children.
16
16
 
17
17
  Arguments:
18
- block: The functional code block to combine with its children.
18
+ root: The functional code block to combine with its children.
19
19
  """
20
- Combiner.combine_children(block)
21
- block.omit_prefix = False
20
+ Combiner.combine_children(root)
21
+ root.omit_prefix = False
22
22
 
23
23
  @staticmethod
24
24
  def combine_children(block: CodeBlock) -> None:
@@ -48,16 +48,11 @@ class Combiner(FileManager):
48
48
  block.complete = children_complete
49
49
  return
50
50
 
51
- # Replace all placeholders
52
51
  missing_children = []
53
52
  for child in block.children:
54
53
  if isinstance(block, TranslatedCodeBlock) and not child.translated:
55
54
  missing_children.append(child)
56
55
  continue
57
- if not Combiner.contains_child(block.text, child):
58
- missing_children.append(child)
59
- continue
60
- block.text = block.text.replace(child.placeholder, child.text)
61
56
 
62
57
  if missing_children:
63
58
  missing_ids = [c.id for c in missing_children]
@@ -66,36 +61,33 @@ class Combiner(FileManager):
66
61
  block.children = missing_children
67
62
  block.complete = children_complete and not missing_children
68
63
 
64
+
65
+ class JsonCombiner(Combiner):
69
66
  @staticmethod
70
- def contains_child(code: str, child: CodeBlock) -> bool:
71
- """Determine whether the given code contains a placeholder for the given
72
- child block.
67
+ def combine(root: CodeBlock) -> None:
68
+ """Combine the given block with its children.
73
69
 
74
70
  Arguments:
75
- code: The code to check for the placeholder
76
- child: The child block to check for
77
-
78
- Returns:
79
- Whether the given code contains a placeholder for the given child
80
- block.
71
+ root: The functional code block to combine with its children.
81
72
  """
82
- return code is None or child.placeholder in code
83
-
73
+ stack = [root]
74
+ while stack:
75
+ block = stack.pop()
76
+ if block.children:
77
+ stack.extend(block.children)
78
+ block.affixes = ("", "")
79
+ else:
80
+ block.affixes = ("\n", "\n")
81
+ super(JsonCombiner, JsonCombiner).combine(root)
82
+
83
+
84
+ class ChunkCombiner(Combiner):
84
85
  @staticmethod
85
- def count_missing(input_block: CodeBlock, output_code: str) -> int:
86
- """Return the number of children of input_block who are not represented
87
- in output_code with a placeholder
86
+ def combine(root: CodeBlock) -> None:
87
+ """A combiner which doesn't actually combine the code blocks,
88
+ instead preserving children
88
89
 
89
90
  Arguments:
90
- input_block: The block to check for missing children
91
- output_code: The code to check for placeholders
92
-
93
- Returns:
94
- The number of children of input_block who are not represented in
95
- output_code with a placeholder
91
+ root: The functional code block to combine with its children.
96
92
  """
97
- missing_children = 0
98
- for child in input_block.children:
99
- if not Combiner.contains_child(output_code, child):
100
- missing_children += 1
101
- return missing_children
93
+ return root
@@ -11,7 +11,7 @@ class TestMumpsSplitter(unittest.TestCase):
11
11
 
12
12
  def setUp(self):
13
13
  """Set up the tests."""
14
- model_name = "gpt-3.5-turbo"
14
+ model_name = "gpt-3.5-turbo-0125"
15
15
  llm, _, _ = load_model(model_name)
16
16
  self.splitter = MumpsSplitter(model=llm)
17
17
  self.combiner = Combiner(language="mumps")
@@ -20,7 +20,7 @@ class TestMumpsSplitter(unittest.TestCase):
20
20
  def test_split(self):
21
21
  """Test the split method."""
22
22
  tree_root = self.splitter.split(self.test_file)
23
- self.assertEqual(len(tree_root.children), 6)
23
+ self.assertEqual(len(tree_root.children), 22)
24
24
  self.assertLessEqual(tree_root.max_tokens, self.splitter.max_tokens)
25
25
  self.assertFalse(tree_root.complete)
26
26
  self.combiner.combine_children(tree_root)
@@ -1,5 +1,4 @@
1
1
  import re
2
- from pathlib import Path
3
2
 
4
3
  from langchain.schema.language_model import BaseLanguageModel
5
4
 
@@ -43,7 +42,13 @@ class MumpsSplitter(Splitter):
43
42
  re.VERBOSE | re.DOTALL,
44
43
  )
45
44
 
46
- def __init__(self, model: None | BaseLanguageModel = None, max_tokens: int = 4096):
45
+ def __init__(
46
+ self,
47
+ model: None | BaseLanguageModel = None,
48
+ max_tokens: int = 4096,
49
+ protected_node_types: tuple[str] = ("routine_definition",),
50
+ prune_node_types: tuple[str] = (),
51
+ ):
47
52
  """Initialize a MumpsSplitter instance.
48
53
 
49
54
  Arguments:
@@ -53,17 +58,18 @@ class MumpsSplitter(Splitter):
53
58
  language="mumps",
54
59
  model=model,
55
60
  max_tokens=max_tokens,
56
- use_placeholders=False,
61
+ protected_node_types=protected_node_types,
62
+ prune_node_types=prune_node_types,
57
63
  )
58
64
 
59
65
  # MUMPS code tends to take about 2/3 the space of Python
60
66
  self.max_tokens: int = int(max_tokens * 2 / 5)
61
67
 
62
- def _set_identifiers(self, root: CodeBlock, path: Path):
68
+ def _set_identifiers(self, root: CodeBlock, name: str):
63
69
  stack = [root]
64
70
  while stack:
65
71
  node = stack.pop()
66
- node.name = f"{path.name}:{node.id}"
72
+ node.name = f"{name}:{node.id}"
67
73
  stack.extend(node.children)
68
74
 
69
75
  def _get_ast(self, code: str) -> CodeBlock:
@@ -104,15 +110,19 @@ class MumpsSplitter(Splitter):
104
110
  start_byte=start_byte,
105
111
  end_byte=end_byte,
106
112
  affixes=(prefix, suffix),
107
- type=NodeType("subroutine"),
113
+ node_type=NodeType("routine_definition"),
108
114
  children=[],
109
115
  language=self.language,
110
116
  tokens=self._count_tokens(chunk),
111
117
  )
118
+ self._split_into_lines(node)
119
+ for line_node in node.children:
120
+ self._split_comment(line_node)
121
+
112
122
  children.append(node)
113
123
 
114
- start_byte = end_byte + len(bytes(suffix, "utf-8"))
115
- start_line = end_line + suffix.count("\n")
124
+ start_byte = end_byte
125
+ start_line = end_line
116
126
 
117
127
  return CodeBlock(
118
128
  text=code,
@@ -122,8 +132,82 @@ class MumpsSplitter(Splitter):
122
132
  end_point=(code.count("\n"), 0),
123
133
  start_byte=0,
124
134
  end_byte=len(bytes(code, "utf-8")),
125
- type=NodeType("routine"),
135
+ node_type=NodeType("routine"),
126
136
  children=children,
127
137
  language=self.language,
128
138
  tokens=self._count_tokens(code),
129
139
  )
140
+
141
+ @staticmethod
142
+ def comment_start(line: str) -> int:
143
+ first_semicolon = line.find(";")
144
+ if first_semicolon < 0:
145
+ return first_semicolon
146
+
147
+ # In mumps, quotes are escaped by doubling them (""). Single quote
148
+ # characters are logical not operators, not quotes
149
+ n_quotes = line[:first_semicolon].replace('""', "").count('"')
150
+
151
+ # If the number of quotes prior to the first semicolon is even, then
152
+ # that semicolon is not part of a quote (and therefore starts a comment)
153
+ if n_quotes % 2 == 0:
154
+ return first_semicolon
155
+
156
+ last_semicolon = first_semicolon
157
+ while (next_semicolon := line.find(";", last_semicolon + 1)) > 0:
158
+ n_quotes = line[last_semicolon:next_semicolon].replace('""', "").count('"')
159
+
160
+ # If the number of quotes in this chunk is odd, the total number
161
+ # of them up to this point is even, and the next semicolon begins
162
+ # the comment
163
+ if n_quotes % 2:
164
+ return next_semicolon
165
+
166
+ last_semicolon = next_semicolon
167
+
168
+ return -1
169
+
170
def _split_comment(self, line_node: CodeBlock):
    """Classify a MUMPS line node, splitting off any trailing comment.

    - Lines with no comment become ``code_line`` nodes.
    - Lines with only whitespace before the ';' become ``comment`` nodes.
    - Lines mixing both get two children: a ``code_line`` node for the text
      before the ';' and a ``comment`` node for the rest. Both children share
      the split point as the code's end and the comment's start.

    Arguments:
        line_node: A node produced by ``_split_into_lines`` (presumably one
            source line -- confirm). It is modified in place.
    """
    start = self.comment_start(line_node.text)
    if start < 0:
        # No comment on this line: it is pure code.
        line_node.node_type = NodeType("code_line")
        return

    code = line_node.text[:start]
    if not code.strip():
        # Only whitespace precedes the ';', so the whole line is a comment.
        line_node.node_type = NodeType("comment")
        return

    comment = line_node.text[start:]
    (l0, c0), (l1, c1) = line_node.start_point, line_node.end_point
    # ``start`` is relative to the node's text, so offset it by the node's
    # starting column. Byte offsets are measured in UTF-8.
    split_point = (l0, c0 + start)
    split_byte = line_node.start_byte + len(bytes(code, "utf-8"))

    line_node.children = [
        CodeBlock(
            text=code,
            name=f"{line_node.name}-code",
            id=f"{line_node.name}-code",
            start_point=(l0, c0),
            end_point=split_point,
            start_byte=line_node.start_byte,
            end_byte=split_byte,
            node_type=NodeType("code_line"),
            children=[],
            language=line_node.language,
            tokens=self._count_tokens(code),
        ),
        CodeBlock(
            text=comment,
            name=f"{line_node.name}-comment",
            id=f"{line_node.name}-comment",
            start_point=split_point,
            end_point=(l1, c1),
            start_byte=split_byte,
            end_byte=line_node.end_byte,
            node_type=NodeType("comment"),
            children=[],
            language=line_node.language,
            tokens=self._count_tokens(comment),
        ),
    ]
@@ -0,0 +1,4 @@
1
+ from .basic_splitter import FileSplitter
2
+ from .chunk_splitter import ChunkSplitter
3
+ from .simple_ast import FlexibleTreeSitterSplitter, StrictTreeSitterSplitter
4
+ from .tag_splitter import TagSplitter
@@ -0,0 +1,14 @@
1
+ from janus.language.block import CodeBlock
2
+ from janus.language.naive.chunk_splitter import ChunkSplitter
3
+ from janus.language.naive.registry import register_splitter
4
+ from janus.language.splitter import FileSizeError
5
+
6
+
7
@register_splitter("file")
class FileSplitter(ChunkSplitter):
    """Splitter that keeps the entire file as one unit.

    Line splitting is refused outright. Any call to ``_split_into_lines``
    raises FileSizeError. Presumably the base Splitter calls it when a node is
    too large to keep whole -- confirm against Splitter.
    """

    def _split_into_lines(self, node: CodeBlock):
        raise FileSizeError("File too large for basic splitter")
@@ -0,0 +1,26 @@
1
+ from janus.language.block import CodeBlock
2
+ from janus.language.naive.registry import register_splitter
3
+ from janus.language.node import NodeType
4
+ from janus.language.splitter import Splitter
5
+
6
+
7
@register_splitter("chunk")
class ChunkSplitter(Splitter):
    """Splits into fixed chunk sizes without parsing.

    No syntax tree is built. The whole file becomes a single childless root,
    and any further chunking is left to the base Splitter.
    """

    def _get_ast(self, code: str) -> CodeBlock:
        """Wrap ``code`` in one childless ``program`` node named "root"."""
        total_bytes = len(bytes(code, "utf-8"))
        last_line = code.count("\n")
        return CodeBlock(
            id="root",
            name="root",
            node_type=NodeType("program"),
            language=self.language,
            text=code,
            start_point=(0, 0),
            end_point=(last_line, 0),
            start_byte=0,
            end_byte=total_bytes,
            tokens=self._count_tokens(code),
            children=[],
        )
@@ -0,0 +1,13 @@
1
+ from typing import Callable, Dict
2
+
3
+ from janus.language.splitter import Splitter
4
+
5
# Registry of splitter classes keyed by the name given to ``register_splitter``.
CUSTOM_SPLITTERS: Dict[str, Callable[..., Splitter]] = {}


def register_splitter(name: str):
    """Return a class decorator that records the decorated splitter under ``name``.

    The decorated object is returned unchanged. Registering another splitter
    under an existing name replaces the earlier entry.
    """

    def _register(splitter_cls):
        CUSTOM_SPLITTERS[name] = splitter_cls
        return splitter_cls

    return _register
@@ -0,0 +1,18 @@
1
+ from janus.language.naive.registry import register_splitter
2
+ from janus.language.treesitter import TreeSitterSplitter
3
+ from janus.utils.enums import LANGUAGES
4
+
5
+
6
@register_splitter("ast-flex")
class FlexibleTreeSitterSplitter(TreeSitterSplitter):
    """TreeSitterSplitter registered as "ast-flex", with its defaults unchanged."""


@register_splitter("ast-strict")
class StrictTreeSitterSplitter(TreeSitterSplitter):
    """TreeSitterSplitter registered as "ast-strict".

    It protects only the language's ``functional_node_type`` (looked up in
    LANGUAGES) and passes ``prune_unprotected=True`` to the base class.
    """

    def __init__(self, language: str, **kwargs):
        functional_type = LANGUAGES[language]["functional_node_type"]
        # These settings always override whatever the caller supplied.
        kwargs["protected_node_types"] = (functional_type,)
        kwargs["prune_unprotected"] = True
        super().__init__(language=language, **kwargs)
@@ -0,0 +1,61 @@
1
+ from janus.language.block import CodeBlock
2
+ from janus.language.naive.registry import register_splitter
3
+ from janus.language.node import NodeType
4
+ from janus.language.splitter import Splitter
5
+
6
+
7
@register_splitter("tag")
class TagSplitter(Splitter):
    """
    Splits code by tags inserted into code

    The tag only acts as a separator when it occupies a line of its own (it is
    matched together with the newline before and after it). Each piece between
    separators becomes a ``chunk`` node under a single ``program`` root.
    """

    def __init__(self, tag: str, *args, **kwargs):
        # Chunks are always protected; this overrides any caller-supplied value.
        kwargs.update(protected_node_types=("chunk",))
        super().__init__(*args, **kwargs)
        self._tag = f"\n{tag}\n"

    def _get_ast(self, code: str) -> CodeBlock:
        """Build a ``program`` root whose children are the tag-separated chunks.

        Chunk positions (lines and UTF-8 byte offsets) refer to the original
        ``code``. The separator that ``str.split`` removes is therefore skipped
        between consecutive chunks, so the last chunk ends where the root ends.
        """
        chunks = code.split(self._tag)
        sep_bytes = len(bytes(self._tag, "utf-8"))
        sep_lines = self._tag.count("\n")
        last_index = len(chunks) - 1
        children = []
        start_line = 0
        start_byte = 0
        for i, chunk in enumerate(chunks):
            # Interior chunks carry the separator on both sides; the first
            # chunk has no prefix and the last chunk has no suffix.
            prefix = "" if i == 0 else self._tag
            suffix = "" if i == last_index else self._tag
            end_byte = start_byte + len(bytes(chunk, "utf-8"))
            end_line = start_line + chunk.count("\n")
            # Characters on the chunk's final line (rfind is -1 with no newline).
            end_char = len(chunk) - chunk.rfind("\n") - 1
            children.append(
                CodeBlock(
                    text=chunk,
                    name=f"Chunk {i}",
                    id=f"Chunk {i}",
                    start_point=(start_line, 0),
                    end_point=(end_line, end_char),
                    start_byte=start_byte,
                    end_byte=end_byte,
                    affixes=(prefix, suffix),
                    node_type=NodeType("chunk"),
                    children=[],
                    language=self.language,
                    tokens=self._count_tokens(chunk),
                )
            )
            # str.split dropped the separator, but it still sits between this
            # chunk and the next one in ``code``, so skip over it.
            start_line = end_line + sep_lines
            start_byte = end_byte + sep_bytes
        return CodeBlock(
            text=code,
            name="root",
            id="root",
            start_point=(0, 0),
            end_point=(code.count("\n"), 0),
            start_byte=0,
            end_byte=len(bytes(code, "utf-8")),
            node_type=NodeType("program"),
            children=children,
            language=self.language,
            tokens=self._count_tokens(code),
        )
+ )