janus-llm 2.0.2__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. The information is provided for informational purposes only.
- janus/__init__.py +2 -2
- janus/__main__.py +1 -1
- janus/_tests/test_cli.py +1 -2
- janus/cli.py +43 -51
- janus/converter/__init__.py +6 -0
- janus/converter/_tests/__init__.py +0 -0
- janus/{_tests → converter/_tests}/test_translate.py +11 -22
- janus/converter/converter.py +614 -0
- janus/converter/diagram.py +124 -0
- janus/converter/document.py +131 -0
- janus/converter/evaluate.py +15 -0
- janus/converter/requirements.py +50 -0
- janus/converter/translate.py +108 -0
- janus/embedding/_tests/test_collections.py +2 -2
- janus/language/_tests/test_splitter.py +1 -1
- janus/language/alc/__init__.py +1 -0
- janus/language/alc/_tests/__init__.py +0 -0
- janus/language/alc/_tests/test_alc.py +28 -0
- janus/language/alc/alc.py +87 -0
- janus/language/block.py +4 -2
- janus/language/combine.py +0 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/naive/__init__.py +1 -1
- janus/language/naive/basic_splitter.py +4 -4
- janus/language/naive/chunk_splitter.py +4 -4
- janus/language/naive/registry.py +1 -1
- janus/language/naive/simple_ast.py +23 -12
- janus/language/naive/tag_splitter.py +4 -4
- janus/language/splitter.py +10 -4
- janus/language/treesitter/treesitter.py +26 -8
- janus/llm/model_callbacks.py +34 -37
- janus/llm/models_info.py +16 -3
- janus/metrics/_tests/test_llm.py +2 -3
- janus/metrics/_tests/test_rouge_score.py +1 -1
- janus/metrics/_tests/test_similarity_score.py +1 -1
- janus/metrics/complexity_metrics.py +3 -4
- janus/metrics/metric.py +3 -4
- janus/metrics/reading.py +27 -5
- janus/prompts/prompt.py +67 -7
- janus/utils/enums.py +6 -5
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/METADATA +1 -1
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/RECORD +45 -35
- janus/converter.py +0 -158
- janus/translate.py +0 -981
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/entry_points.txt +0 -0
janus/converter/diagram.py
ADDED
@@ -0,0 +1,124 @@
+import json
+from copy import deepcopy
+
+from janus.converter.converter import run_if_changed
+from janus.converter.document import Documenter
+from janus.language.block import TranslatedCodeBlock
+from janus.llm.models_info import MODEL_PROMPT_ENGINES
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class DiagramGenerator(Documenter):
+    """DiagramGenerator
+
+    A class that translates code from one programming language to a set of diagrams.
+    """
+
+    def __init__(
+        self,
+        diagram_type="Activity",
+        add_documentation=False,
+        **kwargs,
+    ) -> None:
+        """Initialize the DiagramGenerator class
+
+        Arguments:
+            model: The LLM to use for translation. If an OpenAI model, the
+                `OPENAI_API_KEY` environment variable must be set and the
+                `OPENAI_ORG_ID` environment variable should be set if needed.
+            model_arguments: Additional arguments to pass to the LLM constructor.
+            source_language: The source programming language.
+            max_prompts: The maximum number of prompts to try before giving up.
+            db_path: path to chroma database
+            db_config: database configuraiton
+            diagram_type: type of PLANTUML diagram to generate
+        """
+        super().__init__(**kwargs)
+        self._diagram_type = diagram_type
+        self._add_documentation = add_documentation
+        self._documenter = None
+        if add_documentation:
+            self._diagram_prompt_template_name = "diagram_with_documentation"
+        else:
+            self._diagram_prompt_template_name = "diagram"
+        self._load_diagram_prompt_engine()
+
+    def _add_translation(self, block: TranslatedCodeBlock) -> None:
+        """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
+        `block.original`, setting the relevant fields in the translated block. The
+        `TranslatedCodeBlock` is updated in-pace, nothing is returned. Note that this
+        translates *only* the code for this block, not its children.
+
+        Arguments:
+            block: An empty `TranslatedCodeBlock`
+        """
+        if block.translated:
+            return
+
+        if block.original.text is None:
+            block.translated = True
+            return
+
+        if self._add_documentation:
+            documentation_block = deepcopy(block)
+            super()._add_translation(documentation_block)
+            if not documentation_block.translated:
+                message = "Error: unable to produce documentation for code block"
+                log.info(message)
+                raise ValueError(message)
+            documentation = json.loads(documentation_block.text)["docstring"]
+
+        if self._llm is None:
+            message = (
+                "Model not configured correctly, cannot translate. Try setting "
+                "the model"
+            )
+            log.error(message)
+            raise ValueError(message)
+
+        log.debug(f"[{block.name}] Translating...")
+        log.debug(f"[{block.name}] Input text:\n{block.original.text}")
+
+        self._parser.set_reference(block.original)
+
+        query_and_parse = self.diagram_prompt | self._llm | self._parser
+
+        if self._add_documentation:
+            block.text = query_and_parse.invoke(
+                {
+                    "SOURCE_CODE": block.original.text,
+                    "DIAGRAM_TYPE": self._diagram_type,
+                    "DOCUMENTATION": documentation,
+                }
+            )
+        else:
+            block.text = query_and_parse.invoke(
+                {
+                    "SOURCE_CODE": block.original.text,
+                    "DIAGRAM_TYPE": self._diagram_type,
+                }
+            )
+        block.tokens = self._llm.get_num_tokens(block.text)
+        block.translated = True
+
+        log.debug(f"[{block.name}] Output code:\n{block.text}")
+
+    @run_if_changed(
+        "_diagram_prompt_template_name",
+        "_source_language",
+    )
+    def _load_diagram_prompt_engine(self) -> None:
+        """Load the prompt engine according to this instance's attributes.
+
+        If the relevant fields have not been changed since the last time this method was
+        called, nothing happens.
+        """
+        self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
+            source_language=self._source_language,
+            target_language="text",
+            target_version=None,
+            prompt_template=self._diagram_prompt_template_name,
+        )
+        self.diagram_prompt = self._diagram_prompt_engine.prompt
janus/converter/document.py
ADDED
@@ -0,0 +1,131 @@
+import json
+import re
+from copy import deepcopy
+
+from janus.converter.converter import Converter
+from janus.language.block import TranslatedCodeBlock
+from janus.language.combine import JsonCombiner
+from janus.parsers.doc_parser import (
+    MadlibsDocumentationParser,
+    MultiDocumentationParser,
+)
+from janus.utils.enums import LANGUAGES
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class Documenter(Converter):
+    def __init__(
+        self, source_language: str = "fortran", drop_comments: bool = True, **kwargs
+    ):
+        kwargs.update(source_language=source_language)
+        super().__init__(**kwargs)
+        self.set_prompt("document")
+
+        if drop_comments:
+            comment_node_type = LANGUAGES[source_language].get(
+                "comment_node_type", "comment"
+            )
+            self.set_prune_node_types((comment_node_type,))
+
+        self._load_parameters()
+
+
+class MultiDocumenter(Documenter):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.set_prompt("multidocument")
+        self._combiner = JsonCombiner()
+        self._parser = MultiDocumentationParser()
+
+
+class MadLibsDocumenter(Documenter):
+    def __init__(
+        self,
+        comments_per_request: int | None = None,
+        **kwargs,
+    ) -> None:
+        kwargs.update(drop_comments=False)
+        super().__init__(**kwargs)
+        self.set_prompt("document_madlibs")
+        self._combiner = JsonCombiner()
+        self._parser = MadlibsDocumentationParser()
+
+        self.comments_per_request = comments_per_request
+
+    def _add_translation(self, block: TranslatedCodeBlock):
+        if block.translated:
+            return
+
+        if block.original.text is None:
+            block.translated = True
+            return
+
+        if self.comments_per_request is None:
+            return super()._add_translation(block)
+
+        comment_pattern = r"<(?:INLINE|BLOCK)_COMMENT \w{8}>"
+        comments = list(
+            re.finditer(
+                comment_pattern,
+                block.original.text,
+            )
+        )
+
+        if not comments:
+            log.info(f"[{block.name}] Skipping commentless block")
+            block.translated = True
+            block.text = None
+            block.complete = True
+            return
+
+        if len(comments) <= self.comments_per_request:
+            return super()._add_translation(block)
+
+        comment_group_indices = list(range(0, len(comments), self.comments_per_request))
+        log.debug(
+            f"[{block.name}] Block contains more than {self.comments_per_request}"
+            f" comments, splitting {len(comments)} comments into"
+            f" {len(comment_group_indices)} groups"
+        )
+
+        block.processing_time = 0
+        block.cost = 0
+        block.retries = 0
+        obj = {}
+        for i in range(0, len(comments), self.comments_per_request):
+            # Split the text into the section containing comments of interest,
+            # all the text prior to those comments, and all the text after them
+            working_comments = comments[i : i + self.comments_per_request]
+            start_idx = working_comments[0].start()
+            end_idx = working_comments[-1].end()
+            prefix = block.original.text[:start_idx]
+            keeper = block.original.text[start_idx:end_idx]
+            suffix = block.original.text[end_idx:]
+
+            # Strip all comment placeholders outside of the section of interest
+            prefix = re.sub(comment_pattern, "", prefix)
+            suffix = re.sub(comment_pattern, "", suffix)
+
+            # Build a new TranslatedBlock using the new working text
+            working_copy = deepcopy(block.original)
+            working_copy.text = prefix + keeper + suffix
+            working_block = TranslatedCodeBlock(working_copy, self._target_language)
+
+            # Run the LLM on the working text
+            super()._add_translation(working_block)
+
+            # Update metadata to include for all runs
+            block.retries += working_block.retries
+            block.cost += working_block.cost
+            block.processing_time += working_block.processing_time
+
+            # Update the output text to merge this section's output in
+            out_text = self._parser.parse(working_block.text)
+            obj.update(json.loads(out_text))
+
+        self._parser.set_reference(block.original)
+        block.text = self._parser.parse(json.dumps(obj))
+        block.tokens = self._llm.get_num_tokens(block.text)
+        block.translated = True
janus/converter/evaluate.py
ADDED
@@ -0,0 +1,15 @@
+from janus.converter.converter import Converter
+from janus.language.combine import JsonCombiner
+from janus.parsers.eval_parser import EvaluationParser
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class Evaluator(Converter):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.set_prompt("evaluate")
+        self._combiner = JsonCombiner()
+        self._parser = EvaluationParser()
+        self._load_parameters()
janus/converter/requirements.py
ADDED
@@ -0,0 +1,50 @@
+import json
+from pathlib import Path
+
+from janus.converter.document import Documenter
+from janus.language.block import TranslatedCodeBlock
+from janus.language.combine import ChunkCombiner
+from janus.parsers.reqs_parser import RequirementsParser
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class RequirementsDocumenter(Documenter):
+    """RequirementsGenerator
+
+    A class that translates code from one programming language to its requirements.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.set_prompt("requirements")
+        self._combiner = ChunkCombiner()
+        self._parser = RequirementsParser()
+
+    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
+        """Save a file to disk.
+
+        Arguments:
+            block: The `CodeBlock` to save to a file.
+        """
+        output_list = list()
+        # For each chunk of code, get generation metadata, the text of the code,
+        # and the LLM generated requirements
+        for child in block.children:
+            code = child.original.text
+            requirements = self._parser.parse_combined_output(child.complete_text)
+            metadata = dict(
+                retries=child.total_retries,
+                cost=child.total_cost,
+                processing_time=child.processing_time,
+            )
+            # Put them all in a top level 'output' key
+            output_list.append(
+                dict(metadata=metadata, code=code, requirements=requirements)
+            )
+        obj = dict(
+            output=output_list,
+        )
+        out_path.parent.mkdir(parents=True, exist_ok=True)
+        out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
janus/converter/translate.py
ADDED
@@ -0,0 +1,108 @@
+from janus.converter.converter import Converter, run_if_changed
+from janus.llm.models_info import MODEL_PROMPT_ENGINES
+from janus.parsers.code_parser import CodeParser
+from janus.prompts.prompt import SAME_OUTPUT
+from janus.utils.enums import LANGUAGES
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class Translator(Converter):
+    """A class that translates code from one programming language to another."""
+
+    def __init__(
+        self,
+        target_language: str = "python",
+        target_version: str | None = "3.10",
+        **kwargs,
+    ) -> None:
+        """Initialize a Translator instance.
+
+        Arguments:
+            model: The LLM to use for translation. If an OpenAI model, the
+                `OPENAI_API_KEY` environment variable must be set and the
+                `OPENAI_ORG_ID` environment variable should be set if needed.
+            model_arguments: Additional arguments to pass to the LLM constructor.
+            source_language: The source programming language.
+            target_language: The target programming language.
+            target_version: The target version of the target programming language.
+            max_prompts: The maximum number of prompts to try before giving up.
+            max_tokens: The maximum number of tokens the model will take in.
+                If unspecificed, model's default max will be used.
+            prompt_template: name of prompt template directory
+                (see janus/prompts/templates) or path to a directory.
+        """
+        super().__init__(**kwargs)
+
+        self._target_version: str | None
+
+        self.set_target_language(
+            target_language=target_language,
+            target_version=target_version,
+        )
+
+        self._load_parameters()
+
+    def _load_parameters(self) -> None:
+        self._load_parser()
+        super()._load_parameters()
+
+    def set_target_language(
+        self, target_language: str, target_version: str | None
+    ) -> None:
+        """Validate and set the target language.
+
+        The affected objects will not be updated until translate() is called.
+
+        Arguments:
+            target_language: The target programming language.
+            target_version: The target version of the target programming language.
+        """
+        target_language = target_language.lower()
+        if target_language not in LANGUAGES:
+            raise ValueError(
+                f"Invalid target language: {target_language}. "
+                "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
+            )
+        self._target_language = target_language
+        self._target_version = target_version
+        self._target_suffix = f".{LANGUAGES[target_language]['suffix']}"
+
+    @run_if_changed(
+        "_prompt_template_name",
+        "_source_language",
+        "_target_language",
+        "_target_version",
+        "_model_name",
+    )
+    def _load_prompt(self) -> None:
+        """Load the prompt according to this instance's attributes.
+
+        If the relevant fields have not been changed since the last time this
+        method was called, nothing happens.
+        """
+        if self._prompt_template_name in SAME_OUTPUT:
+            if self._target_language != self._source_language:
+                raise ValueError(
+                    f"Prompt template ({self._prompt_template_name}) suggests "
+                    f"source and target languages should match, but do not "
+                    f"({self._source_language} != {self._target_language})"
+                )
+
+        prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
+            source_language=self._source_language,
+            target_language=self._target_language,
+            target_version=self._target_version,
+            prompt_template=self._prompt_template_name,
+        )
+        self._prompt = prompt_engine.prompt
+
+    @run_if_changed("_target_language")
+    def _load_parser(self) -> None:
+        """Load the parser according to this instance's attributes.
+
+        If the relevant fields have not been changed since the last time this
+        method was called, nothing happens.
+        """
+        self._parser = CodeParser(language=self._target_language)
janus/embedding/_tests/test_collections.py
CHANGED
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
 
 import pytest
 
-from
-from
+from ...utils.enums import EmbeddingType
+from ..collections import Collections
 
 
 class TestCollections(unittest.TestCase):
janus/language/alc/__init__.py
ADDED
@@ -0,0 +1 @@
+from .alc import AlcCombiner, AlcSplitter
janus/language/alc/_tests/__init__.py
File without changes
janus/language/alc/_tests/test_alc.py
ADDED
@@ -0,0 +1,28 @@
+import unittest
+from pathlib import Path
+
+from ....llm import load_model
+from ...combine import Combiner
+from ..alc import AlcSplitter
+
+
+class TestAlcSplitter(unittest.TestCase):
+    """Tests for the Splitter class."""
+
+    def setUp(self):
+        """Set up the tests."""
+        model_name = "gpt-3.5-turbo-0125"
+        llm, _, _ = load_model(model_name)
+        self.splitter = AlcSplitter(model=llm)
+        self.combiner = Combiner(language="ibmhlasm")
+        self.test_file = Path("janus/language/alc/_tests/alc.asm")
+
+    def test_split(self):
+        """Test the split method."""
+        tree_root = self.splitter.split(self.test_file)
+        self.assertEqual(tree_root.n_descendents, 34)
+        self.assertLessEqual(tree_root.max_tokens, self.splitter.max_tokens)
+        self.assertFalse(tree_root.complete)
+        self.combiner.combine_children(tree_root)
+        self.assertTrue(tree_root.complete)
+        self.assertEqual(tree_root.complete_text, self.test_file.read_text())
janus/language/alc/alc.py
ADDED
@@ -0,0 +1,87 @@
+from langchain.schema.language_model import BaseLanguageModel
+
+from ...utils.logger import create_logger
+from ..block import CodeBlock
+from ..combine import Combiner
+from ..node import NodeType
+from ..treesitter import TreeSitterSplitter
+
+log = create_logger(__name__)
+
+
+class AlcCombiner(Combiner):
+    """A class that combines code blocks into ALC files."""
+
+    def __init__(self) -> None:
+        """Initialize a AlcCombiner instance."""
+        super().__init__("ibmhlasm")
+
+
+class AlcSplitter(TreeSitterSplitter):
+    """A class for splitting ALC code into functional blocks to prompt
+    with for transcoding.
+    """
+
+    def __init__(
+        self,
+        model: None | BaseLanguageModel = None,
+        max_tokens: int = 4096,
+        protected_node_types: tuple[str, ...] = (),
+        prune_node_types: tuple[str, ...] = (),
+        prune_unprotected: bool = False,
+    ):
+        """Initialize a AlcSplitter instance.
+
+        Arguments:
+            max_tokens: The maximum number of tokens supported by the model
+        """
+        super().__init__(
+            language="ibmhlasm",
+            model=model,
+            max_tokens=max_tokens,
+            protected_node_types=protected_node_types,
+            prune_node_types=prune_node_types,
+            prune_unprotected=prune_unprotected,
+        )
+
+    def _get_ast(self, code: str) -> CodeBlock:
+        root = super()._get_ast(code)
+
+        # Current treesitter implementation does not nest csects and dsects
+        # The loop below nests nodes following csect/dsect instructions into
+        # the children of that instruction
+        sect_types = {"csect_instruction", "dsect_instruction"}
+        queue: list[CodeBlock] = [root]
+        while queue:
+            block = queue.pop(0)
+
+            # Search this children for csects and dsects. Create a list of groups
+            # where each group is a csect or dsect, starting with the csect/dsect
+            # instruction and containing all the subsequent nodes up until the
+            # next csect or dsect instruction
+            sects: list[list[CodeBlock]] = [[]]
+            for c in block.children:
+                if c.node_type in sect_types:
+                    sects.append([c])
+                else:
+                    sects[-1].append(c)
+
+            sects = [s for s in sects if s]
+
+            # Restructure the tree, making the head of each group the parent
+            # of all the remaining nodes in that group
+            if len(sects) > 1:
+                block.children = []
+                for sect in sects:
+                    if sect[0].node_type in sect_types:
+                        sect_node = self.merge_nodes(sect)
+                        sect_node.children = sect
+                        sect_node.node_type = NodeType(str(sect[0].node_type)[:5])
+                        block.children.append(sect_node)
+                    else:
+                        block.children.extend(sect)
+
+            # Push the children onto the queue
+            queue.extend(block.children)
+
+        return root
janus/language/block.py
CHANGED
@@ -152,9 +152,11 @@ class CodeBlock:
         Returns:
             A string representation of the tree with this block as the root
         """
+        tokens = self.tokens
         identifier = self.id
         if self.text is None:
             identifier = f"({identifier})"
+            tokens = self.total_tokens
         elif not self.complete:
             identifier += "*"
         if self.start_point is not None and self.end_point is not None:
@@ -165,7 +167,7 @@ class CodeBlock:
             seg = ""
         return "\n".join(
             [
-                f"{'| '*depth}{identifier}{seg}",
+                f"{'| '*depth}{identifier}{seg} ({tokens:,d} tokens)",
                 *[c.tree_str(depth + 1) for c in self.children],
             ]
         )
@@ -214,7 +216,7 @@ class TranslatedCodeBlock(CodeBlock):
         self.translated = False
         self.cost = 0.0
         self.retries = 0
-        self.processing_time = 0
+        self.processing_time = 0.0
 
     @property
     def total_cost(self) -> float:
janus/language/combine.py
CHANGED
janus/language/mumps/mumps.py
CHANGED
@@ -48,6 +48,7 @@ class MumpsSplitter(Splitter):
         max_tokens: int = 4096,
         protected_node_types: tuple[str] = ("routine_definition",),
         prune_node_types: tuple[str] = (),
+        prune_unprotected: bool = False,
     ):
         """Initialize a MumpsSplitter instance.
 
@@ -60,11 +61,9 @@ class MumpsSplitter(Splitter):
             max_tokens=max_tokens,
             protected_node_types=protected_node_types,
             prune_node_types=prune_node_types,
+            prune_unprotected=prune_unprotected,
         )
 
-        # MUMPS code tends to take about 2/3 the space of Python
-        self.max_tokens: int = int(max_tokens * 2 / 5)
-
     def _set_identifiers(self, root: CodeBlock, name: str):
         stack = [root]
         while stack:
janus/language/naive/__init__.py
CHANGED
janus/language/naive/basic_splitter.py
CHANGED
@@ -1,7 +1,7 @@
-from
-from
-from
-from
+from ..block import CodeBlock
+from ..naive.chunk_splitter import ChunkSplitter
+from ..naive.registry import register_splitter
+from ..splitter import FileSizeError
 
 
 @register_splitter("file")
janus/language/naive/chunk_splitter.py
CHANGED
@@ -1,7 +1,7 @@
-from
-from
-from
-from
+from ..block import CodeBlock
+from ..node import NodeType
+from ..splitter import Splitter
+from .registry import register_splitter
 
 
 @register_splitter("chunk")
|