janus-llm 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +120 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +9 -6
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +134 -70
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/METADATA +23 -10
  70. janus_llm-2.0.0.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/entry_points.txt +0 -0
janus/language/_tests/test_combine.py ADDED
@@ -0,0 +1,62 @@
+ import unittest
+
+ from ..combine import CodeBlock, Combiner, TranslatedCodeBlock
+
+
+ class TestCombiner(unittest.TestCase):
+     def setUp(self):
+         self.combiner = Combiner()
+         self.block = CodeBlock(
+             id=1,
+             name="test",
+             node_type="test",
+             language="python",
+             text="# test",
+             start_point=(0, 0),
+             end_point=(0, 0),
+             start_byte=0,
+             end_byte=0,
+             tokens=[],
+             children=[
+                 CodeBlock(
+                     id=2,
+                     name="child",
+                     node_type="test",
+                     language="python",
+                     text="test",
+                     start_point=(0, 0),
+                     end_point=(0, 0),
+                     start_byte=0,
+                     end_byte=0,
+                     tokens=[],
+                     children=[],
+                 )
+             ],
+         )
+         self.translated_block = TranslatedCodeBlock(
+             self.block,
+             language="python",
+         )
+
+     def test_combine(self):
+         self.combiner.combine(self.block)
+         self.assertFalse(self.block.omit_prefix)
+
+     def test_combine_children(self):
+         self.block.complete = False
+         self.combiner.combine_children(self.block)
+         self.assertTrue(self.block.complete)
+
+     def test_combine_children_with_translated_block(self):
+         self.translated_block.complete = False
+         self.combiner.combine_children(self.translated_block)
+         self.assertFalse(self.translated_block.complete)
+
+     def test_combine_children_with_text_none(self):
+         self.combiner.combine_children(self.block)
+         self.assertEqual(self.block.text, "# test")
+         self.assertTrue(self.block.complete)
+
+
+ if __name__ == "__main__":
+     unittest.main()
janus/language/_tests/test_splitter.py ADDED
@@ -0,0 +1,16 @@
+ import unittest
+
+ from janus.language.splitter import Splitter
+
+
+ class TestSplitter(unittest.TestCase):
+     def setUp(self):
+         self.splitter = Splitter(language="python")
+
+     def test_split(self):
+         input_data = "janus/__main__.py"
+         self.assertRaises(NotImplementedError, self.splitter.split, input_data)
+
+
+ if __name__ == "__main__":
+     unittest.main()
janus/language/binary/_tests/test_binary.py CHANGED
@@ -1,6 +1,7 @@
  import os
  import unittest
  from pathlib import Path
+ from unittest.mock import patch

  import pytest

@@ -12,12 +13,26 @@ class TestBinarySplitter(unittest.TestCase):
      """Tests for the BinarySplitter class."""

      def setUp(self):
-         model_name = "gpt-3.5-turbo"
+         model_name = "gpt-3.5-turbo-0125"
          self.binary_file = Path("janus/language/binary/_tests/hello")
          self.llm, _, _ = load_model(model_name)
          self.splitter = BinarySplitter(model=self.llm)
          os.environ["GHIDRA_INSTALL_PATH"] = "~/programs/ghidra_10.4_PUBLIC"

+     def test_setup(self):
+         """Test that the setup sets the environment variable correctly."""
+         with patch("os.getenv") as mock_getenv:
+             mock_getenv.return_value = "~/programs/ghidra_10.4_PUBLIC"
+             self.assertEqual(
+                 os.getenv("GHIDRA_INSTALL_PATH"), "~/programs/ghidra_10.4_PUBLIC"
+             )
+             mock_getenv.assert_called_once_with("GHIDRA_INSTALL_PATH")
+
+     def test_initialization(self):
+         """Test that BinarySplitter is initialized correctly."""
+         self.assertIsInstance(self.splitter, BinarySplitter)
+         self.assertEqual(self.splitter.model, self.llm)
+
      @pytest.mark.ghidra(
          reason=(
              "No way to test this in CI w/o installing Ghidra, but want to keep here to "
janus/language/binary/binary.py CHANGED
@@ -29,7 +29,13 @@ class BinarySplitter(TreeSitterSplitter):
      with for transcoding.
      """

-     def __init__(self, model: None | BaseLanguageModel = None, max_tokens: int = 4096):
+     def __init__(
+         self,
+         model: None | BaseLanguageModel = None,
+         max_tokens: int = 4096,
+         protected_node_types: tuple[str] = (),
+         prune_node_types: tuple[str] = (),
+     ):
          """Initialize a BinarySplitter instance.

          Arguments:
@@ -40,7 +46,8 @@ class BinarySplitter(TreeSitterSplitter):
              language="binary",
              model=model,
              max_tokens=max_tokens,
-             use_placeholders=False,
+             protected_node_types=protected_node_types,
+             prune_node_types=prune_node_types,
          )

      def _execute_ghidra_script(self, cmd: list[str]) -> str:
@@ -131,7 +138,7 @@ class BinarySplitter(TreeSitterSplitter):
          code = self._get_decompilation(file)

          root = self._get_ast(code)
-         self._set_identifiers(root, path)
+         self._set_identifiers(root, path.name)
          self._segment_leaves(root)
          self._merge_tree(root)

janus/language/block.py CHANGED
@@ -1,5 +1,5 @@
  from functools import total_ordering
- from typing import ForwardRef, Hashable
+ from typing import ForwardRef, Hashable, Optional, Tuple

  from ..utils.logger import create_logger
  from .node import NodeType
@@ -14,7 +14,7 @@ class CodeBlock:
      Attributes:
          id: The id of the code block in the AST
          name: Descriptive name of node
-         type: The type of the code block ('function', 'module', etc.). Defined in the
+         node_type: The type of the code block ('function', 'module', etc.). Defined in the
              language-specific modules.
          language: The language of the code block.
          text: The code block.
@@ -33,32 +33,32 @@ class CodeBlock:
      def __init__(
          self,
          id: Hashable,
-         name: None | str,
-         type: NodeType,
+         name: Optional[str],
+         node_type: NodeType,
          language: str,
-         text: None | str,
-         start_point: None | tuple[int, int],
-         end_point: None | tuple[int, int],
-         start_byte: None | int,
-         end_byte: None | int,
+         text: Optional[str],
+         start_point: Optional[Tuple[int, int]],
+         end_point: Optional[Tuple[int, int]],
+         start_byte: Optional[int],
+         end_byte: Optional[int],
          tokens: int,
          children: list[ForwardRef("CodeBlock")],
-         embedding_id: None | str = None,
-         affixes: tuple[str, str] = ("", ""),
+         embedding_id: Optional[str] = None,
+         affixes: Tuple[str, str] = ("", ""),
      ) -> None:
          self.id: Hashable = id
-         self.name: None | str = name
-         self.type: NodeType = type
+         self.name: Optional[str] = name
+         self.node_type: NodeType = node_type
          self.language: str = language
-         self.text: None | str = text
-         self.start_point: None | tuple[int, int] = start_point
-         self.end_point: None | tuple[int, int] = end_point
-         self.start_byte: None | [int] = start_byte
-         self.end_byte: None | [int] = end_byte
+         self.text: Optional[str] = text
+         self.start_point: Optional[Tuple[int, int]] = start_point
+         self.end_point: Optional[Tuple[int, int]] = end_point
+         self.start_byte: Optional[int] = start_byte
+         self.end_byte: Optional[int] = end_byte
          self.tokens: int = tokens
          self.children: list[ForwardRef("CodeBlock")] = sorted(children)
-         self.embedding_id: None | [str] = embedding_id
-         self.affixes: tuple[str, str] = affixes
+         self.embedding_id: Optional[str] = embedding_id
+         self.affixes: Tuple[str, str] = affixes

          self.complete = True
          self.omit_prefix = True
@@ -83,15 +83,7 @@ class CodeBlock:

      @property
      def complete_text(self) -> str:
-         return f"{self.prefix}{self.text}{self.suffix}"
-
-     @property
-     def placeholder(self) -> str:
-         return f"<<<{self.id}>>>"
-
-     @property
-     def complete_placeholder(self) -> str:
-         return f"{self.prefix}<<<{self.id}>>>{self.suffix}"
+         return f"{self.prefix}{self.text or ''}{self.suffix}"

      @property
      def n_descendents(self) -> int:
@@ -146,6 +138,14 @@ class CodeBlock:
          self.affixes = (self.affixes[0], "")
          return suffix

+     def rebuild_text_from_children(self):
+         if self.children:
+             prefix = self.affixes[0] + self.children[0].pop_prefix()
+             suffix = self.children[-1].pop_suffix() + self.affixes[1]
+             self.text = "".join(c.complete_text for c in self.children)
+             self.affixes = (prefix, suffix)
+             self.tokens = sum(c.tokens for c in self.children)
+
      def tree_str(self, depth: int = 0) -> str:
          """A string representation of the tree with this block as the root

@@ -195,7 +195,7 @@ class TranslatedCodeBlock(CodeBlock):
          super().__init__(
              id=original.id,
              name=original.name,
-             type=original.type,
+             node_type=original.node_type,
              language=language,
              text=None,
              start_point=original.start_point,
@@ -214,6 +214,7 @@ class TranslatedCodeBlock(CodeBlock):
          self.translated = False
          self.cost = 0.0
          self.retries = 0
+         self.processing_time = 0

      @property
      def total_cost(self) -> float:
janus/language/combine.py CHANGED
@@ -11,14 +11,14 @@ class Combiner(FileManager):
      """

      @staticmethod
-     def combine(block: CodeBlock) -> None:
+     def combine(root: CodeBlock) -> None:
          """Combine the given block with its children.

          Arguments:
-             block: The functional code block to combine with its children.
+             root: The functional code block to combine with its children.
          """
-         Combiner.combine_children(block)
-         block.omit_prefix = False
+         Combiner.combine_children(root)
+         root.omit_prefix = False

      @staticmethod
      def combine_children(block: CodeBlock) -> None:
@@ -48,16 +48,11 @@ class Combiner(FileManager):
              block.complete = children_complete
              return

-         # Replace all placeholders
          missing_children = []
          for child in block.children:
              if isinstance(block, TranslatedCodeBlock) and not child.translated:
                  missing_children.append(child)
                  continue
-             if not Combiner.contains_child(block.text, child):
-                 missing_children.append(child)
-                 continue
-             block.text = block.text.replace(child.placeholder, child.text)

          if missing_children:
              missing_ids = [c.id for c in missing_children]
@@ -66,36 +61,33 @@ class Combiner(FileManager):
              block.children = missing_children
          block.complete = children_complete and not missing_children

+
+ class JsonCombiner(Combiner):
      @staticmethod
-     def contains_child(code: str, child: CodeBlock) -> bool:
-         """Determine whether the given code contains a placeholder for the given
-         child block.
+     def combine(root: CodeBlock) -> None:
+         """Combine the given block with its children.

          Arguments:
-             code: The code to check for the placeholder
-             child: The child block to check for
-
-         Returns:
-             Whether the given code contains a placeholder for the given child
-             block.
+             root: The functional code block to combine with its children.
          """
-         return code is None or child.placeholder in code
-
+         stack = [root]
+         while stack:
+             block = stack.pop()
+             if block.children:
+                 stack.extend(block.children)
+                 block.affixes = ("", "")
+             else:
+                 block.affixes = ("\n", "\n")
+         super(JsonCombiner, JsonCombiner).combine(root)
+
+
+ class ChunkCombiner(Combiner):
      @staticmethod
-     def count_missing(input_block: CodeBlock, output_code: str) -> int:
-         """Return the number of children of input_block who are not represented
-         in output_code with a placeholder
+     def combine(root: CodeBlock) -> None:
+         """A combiner which doesn't actually combine the code blocks,
+         instead preserving children

          Arguments:
-             input_block: The block to check for missing children
-             output_code: The code to check for placeholders
-
-         Returns:
-             The number of children of input_block who are not represented in
-             output_code with a placeholder
+             root: The functional code block to combine with its children.
          """
-         missing_children = 0
-         for child in input_block.children:
-             if not Combiner.contains_child(output_code, child):
-                 missing_children += 1
-         return missing_children
+         return root
janus/language/mumps/_tests/test_mumps.py CHANGED
@@ -11,7 +11,7 @@ class TestMumpsSplitter(unittest.TestCase):

      def setUp(self):
          """Set up the tests."""
-         model_name = "gpt-3.5-turbo"
+         model_name = "gpt-3.5-turbo-0125"
          llm, _, _ = load_model(model_name)
          self.splitter = MumpsSplitter(model=llm)
          self.combiner = Combiner(language="mumps")
@@ -20,7 +20,7 @@ class TestMumpsSplitter(unittest.TestCase):
      def test_split(self):
          """Test the split method."""
          tree_root = self.splitter.split(self.test_file)
-         self.assertEqual(len(tree_root.children), 6)
+         self.assertEqual(len(tree_root.children), 22)
          self.assertLessEqual(tree_root.max_tokens, self.splitter.max_tokens)
          self.assertFalse(tree_root.complete)
          self.combiner.combine_children(tree_root)
janus/language/mumps/mumps.py CHANGED
@@ -1,5 +1,4 @@
  import re
- from pathlib import Path

  from langchain.schema.language_model import BaseLanguageModel

@@ -43,7 +42,13 @@ class MumpsSplitter(Splitter):
          re.VERBOSE | re.DOTALL,
      )

-     def __init__(self, model: None | BaseLanguageModel = None, max_tokens: int = 4096):
+     def __init__(
+         self,
+         model: None | BaseLanguageModel = None,
+         max_tokens: int = 4096,
+         protected_node_types: tuple[str] = ("routine_definition",),
+         prune_node_types: tuple[str] = (),
+     ):
          """Initialize a MumpsSplitter instance.

          Arguments:
@@ -53,17 +58,18 @@ class MumpsSplitter(Splitter):
              language="mumps",
              model=model,
              max_tokens=max_tokens,
-             use_placeholders=False,
+             protected_node_types=protected_node_types,
+             prune_node_types=prune_node_types,
          )

          # MUMPS code tends to take about 2/3 the space of Python
          self.max_tokens: int = int(max_tokens * 2 / 5)

-     def _set_identifiers(self, root: CodeBlock, path: Path):
+     def _set_identifiers(self, root: CodeBlock, name: str):
          stack = [root]
          while stack:
              node = stack.pop()
-             node.name = f"{path.name}:{node.id}"
+             node.name = f"{name}:{node.id}"
              stack.extend(node.children)

      def _get_ast(self, code: str) -> CodeBlock:
@@ -104,15 +110,19 @@ class MumpsSplitter(Splitter):
                  start_byte=start_byte,
                  end_byte=end_byte,
                  affixes=(prefix, suffix),
-                 type=NodeType("subroutine"),
+                 node_type=NodeType("routine_definition"),
                  children=[],
                  language=self.language,
                  tokens=self._count_tokens(chunk),
              )
+             self._split_into_lines(node)
+             for line_node in node.children:
+                 self._split_comment(line_node)
+
              children.append(node)

-             start_byte = end_byte + len(bytes(suffix, "utf-8"))
-             start_line = end_line + suffix.count("\n")
+             start_byte = end_byte
+             start_line = end_line

          return CodeBlock(
              text=code,
@@ -122,8 +132,82 @@ class MumpsSplitter(Splitter):
              end_point=(code.count("\n"), 0),
              start_byte=0,
              end_byte=len(bytes(code, "utf-8")),
-             type=NodeType("routine"),
+             node_type=NodeType("routine"),
              children=children,
              language=self.language,
              tokens=self._count_tokens(code),
          )
+
+     @staticmethod
+     def comment_start(line: str) -> int:
+         first_semicolon = line.find(";")
+         if first_semicolon < 0:
+             return first_semicolon
+
+         # In mumps, quotes are escaped by doubling them (""). Single quote
+         # characters are logical not operators, not quotes
+         n_quotes = line[:first_semicolon].replace('""', "").count('"')
+
+         # If the number of quotes prior to the first semicolon is even, then
+         # that semicolon is not part of a quote (and therefore starts a comment)
+         if n_quotes % 2 == 0:
+             return first_semicolon
+
+         last_semicolon = first_semicolon
+         while (next_semicolon := line.find(";", last_semicolon + 1)) > 0:
+             n_quotes = line[last_semicolon:next_semicolon].replace('""', "").count('"')
+
+             # If the number of quotes in this chunk is odd, the total number
+             # of them up to this point is even, and the next semicolon begins
+             # the comment
+             if n_quotes % 2:
+                 return next_semicolon
+
+             last_semicolon = next_semicolon
+
+         return -1
+
+     def _split_comment(self, line_node: CodeBlock):
+         comment_start = self.comment_start(line_node.text)
+         if comment_start < 0:
+             line_node.node_type = NodeType("code_line")
+             return
+
+         code = line_node.text[:comment_start]
+         if not code.strip():
+             line_node.node_type = NodeType("comment")
+             return
+
+         comment = line_node.text[comment_start:]
+         (l0, c0), (l1, c1) = line_node.start_point, line_node.end_point
+         prefix, suffix = line_node.affixes
+         code_bytes = len(bytes(code, "utf-8"))
+
+         line_node.children = [
+             CodeBlock(
+                 text=code,
+                 name=f"{line_node.name}-code",
+                 id=f"{line_node.name}-code",
+                 start_point=(l0, c0),
+                 end_point=(l1, comment_start),
+                 start_byte=line_node.start_byte,
+                 end_byte=line_node.start_byte + code_bytes,
+                 node_type=NodeType("code_line"),
+                 children=[],
+                 language=line_node.language,
+                 tokens=self._count_tokens(code),
+             ),
+             CodeBlock(
+                 text=comment,
+                 name=f"{line_node.name}-comment",
+                 id=f"{line_node.name}-comment",
+                 start_point=(l0, c0 + comment_start),
+                 end_point=(l1, c1),
+                 start_byte=line_node.start_byte + code_bytes,
+                 end_byte=line_node.end_byte,
+                 node_type=NodeType("comment"),
+                 children=[],
+                 language=self.language,
+                 tokens=self._count_tokens(comment),
+             ),
+         ]
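
The new MumpsSplitter.comment_start helper uses quote parity to skip semicolons that sit inside MUMPS string literals (where quotes are escaped by doubling them) before marking where a comment begins. A minimal sketch of how it behaves, assuming the module path janus.language.mumps.mumps for the class shown above:

# comment_start is a staticmethod, so no model is needed to call it.
from janus.language.mumps.mumps import MumpsSplitter  # import path assumed

line = 'WRITE "a;b" ; trailing comment'
idx = MumpsSplitter.comment_start(line)
# The first ";" (index 8) is inside a string literal, so the comment is
# detected at the second ";" (index 12).
print(idx, repr(line[idx:]))  # 12 '; trailing comment'

# A line with no semicolon has no comment, so the helper returns -1.
print(MumpsSplitter.comment_start("SET X=1"))  # -1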
janus/language/naive/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .basic_splitter import FileSplitter
+ from .chunk_splitter import ChunkSplitter
+ from .simple_ast import FlexibleTreeSitterSplitter, StrictTreeSitterSplitter
+ from .tag_splitter import TagSplitter
janus/language/naive/basic_splitter.py ADDED
@@ -0,0 +1,14 @@
+ from janus.language.block import CodeBlock
+ from janus.language.naive.chunk_splitter import ChunkSplitter
+ from janus.language.naive.registry import register_splitter
+ from janus.language.splitter import FileSizeError
+
+
+ @register_splitter("file")
+ class FileSplitter(ChunkSplitter):
+     """
+     Splits based on the entire file of the code
+     """
+
+     def _split_into_lines(self, node: CodeBlock):
+         raise FileSizeError("File too large for basic splitter")
janus/language/naive/chunk_splitter.py ADDED
@@ -0,0 +1,26 @@
+ from janus.language.block import CodeBlock
+ from janus.language.naive.registry import register_splitter
+ from janus.language.node import NodeType
+ from janus.language.splitter import Splitter
+
+
+ @register_splitter("chunk")
+ class ChunkSplitter(Splitter):
+     """
+     Splits into fixed chunk sizes without parsing
+     """
+
+     def _get_ast(self, code: str) -> CodeBlock:
+         return CodeBlock(
+             text=code,
+             name="root",
+             id="root",
+             start_point=(0, 0),
+             end_point=(code.count("\n"), 0),
+             start_byte=0,
+             end_byte=len(bytes(code, "utf-8")),
+             node_type=NodeType("program"),
+             children=[],
+             language=self.language,
+             tokens=self._count_tokens(code),
+         )
janus/language/naive/registry.py ADDED
@@ -0,0 +1,13 @@
+ from typing import Callable, Dict
+
+ from janus.language.splitter import Splitter
+
+ CUSTOM_SPLITTERS: Dict[str, Callable[..., Splitter]] = dict()
+
+
+ def register_splitter(name: str):
+     def callback(splitter):
+         CUSTOM_SPLITTERS[name] = splitter
+         return splitter
+
+     return callback
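
register_splitter is a plain registration decorator: it records the decorated class in CUSTOM_SPLITTERS under the given name and hands the class back unchanged. A minimal sketch of that contract (the key "my-chunk" and the subclass are hypothetical, used only for illustration):

from janus.language.naive.chunk_splitter import ChunkSplitter
from janus.language.naive.registry import CUSTOM_SPLITTERS, register_splitter


@register_splitter("my-chunk")  # hypothetical registry key
class MyChunkSplitter(ChunkSplitter):
    """Pass-through subclass used only to demonstrate registration."""


# The decorator stored the class and returned it unchanged.
assert CUSTOM_SPLITTERS["my-chunk"] is MyChunkSplitter
# Importing janus.language.naive likewise registers the built-in keys
# "file", "chunk", "ast-flex", "ast-strict", and "tag".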
janus/language/naive/simple_ast.py ADDED
@@ -0,0 +1,18 @@
+ from janus.language.naive.registry import register_splitter
+ from janus.language.treesitter import TreeSitterSplitter
+ from janus.utils.enums import LANGUAGES
+
+
+ @register_splitter("ast-flex")
+ class FlexibleTreeSitterSplitter(TreeSitterSplitter):
+     pass
+
+
+ @register_splitter("ast-strict")
+ class StrictTreeSitterSplitter(TreeSitterSplitter):
+     def __init__(self, language: str, **kwargs):
+         kwargs.update(
+             protected_node_types=(LANGUAGES[language]["functional_node_type"],),
+             prune_unprotected=True,
+         )
+         super().__init__(language=language, **kwargs)
janus/language/naive/tag_splitter.py ADDED
@@ -0,0 +1,61 @@
+ from janus.language.block import CodeBlock
+ from janus.language.naive.registry import register_splitter
+ from janus.language.node import NodeType
+ from janus.language.splitter import Splitter
+
+
+ @register_splitter("tag")
+ class TagSplitter(Splitter):
+     """
+     Splits code by tags inserted into code
+     """
+
+     def __init__(self, tag: str, *args, **kwargs):
+         kwargs.update(protected_node_types=("chunk",))
+         super().__init__(*args, **kwargs)
+         self._tag = f"\n{tag}\n"
+
+     def _get_ast(self, code: str) -> CodeBlock:
+         chunks = code.split(self._tag)
+         children = []
+         start_line = 0
+         start_byte = 0
+         for i, chunk in enumerate(chunks):
+             prefix = suffix = self._tag
+             if i == 0:
+                 prefix = ""
+             if i == len(chunks) - 1:
+                 suffix = ""
+             end_byte = start_byte + len(bytes(chunk, "utf-8"))
+             end_line = start_line + chunk.count("\n")
+             end_char = len(chunk) - chunk.rfind("\n") - 1
+             node = CodeBlock(
+                 text=chunk,
+                 name=f"Chunk {i}",
+                 id=f"Chunk {i}",
+                 start_point=(start_line, 0),
+                 end_point=(end_line, end_char),
+                 start_byte=start_byte,
+                 end_byte=end_byte,
+                 affixes=(prefix, suffix),
+                 node_type=NodeType("chunk"),
+                 children=[],
+                 language=self.language,
+                 tokens=self._count_tokens(chunk),
+             )
+             children.append(node)
+             start_line = end_line
+             start_byte = end_byte
+         return CodeBlock(
+             text=code,
+             name="root",
+             id="root",
+             start_point=(0, 0),
+             end_point=(code.count("\n"), 0),
+             start_byte=0,
+             end_byte=len(bytes(code, "utf-8")),
+             node_type=NodeType("program"),
+             children=children,
+             language=self.language,
+             tokens=self._count_tokens(code),
+         )
+ )