janus-llm 3.5.2__py3-none-any.whl → 4.0.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- janus/__init__.py +1 -1
- janus/cli.py +90 -42
- janus/converter/converter.py +111 -142
- janus/converter/diagram.py +21 -109
- janus/converter/translate.py +1 -1
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +16 -11
- janus/language/binary/_tests/test_binary.py +1 -1
- janus/language/binary/binary.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +1 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/naive/simple_ast.py +3 -2
- janus/language/splitter.py +7 -4
- janus/language/treesitter/_tests/test_treesitter.py +1 -1
- janus/language/treesitter/treesitter.py +2 -2
- janus/llm/model_callbacks.py +13 -0
- janus/llm/models_info.py +118 -71
- janus/metrics/metric.py +15 -14
- janus/parsers/uml.py +60 -23
- janus/refiners/refiner.py +106 -64
- janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/METADATA +1 -1
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/RECORD +26 -26
- janus/parsers/refiner_parser.py +0 -46
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/entry_points.txt +0 -0
janus/converter/diagram.py
CHANGED
@@ -1,14 +1,6 @@
|
|
1
|
-
import
|
1
|
+
from langchain_core.runnables import Runnable, RunnableParallel
|
2
2
|
|
3
|
-
from langchain.output_parsers import RetryWithErrorOutputParser
|
4
|
-
from langchain_core.exceptions import OutputParserException
|
5
|
-
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
6
|
-
|
7
|
-
from janus.converter.converter import run_if_changed
|
8
3
|
from janus.converter.document import Documenter
|
9
|
-
from janus.language.block import TranslatedCodeBlock
|
10
|
-
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
11
|
-
from janus.parsers.refiner_parser import RefinerParser
|
12
4
|
from janus.parsers.uml import UMLSyntaxParser
|
13
5
|
from janus.utils.logger import create_logger
|
14
6
|
|
@@ -16,10 +8,7 @@ log = create_logger(__name__)
|
|
16
8
|
|
17
9
|
|
18
10
|
class DiagramGenerator(Documenter):
|
19
|
-
"""
|
20
|
-
|
21
|
-
A class that translates code from one programming language to a set of diagrams.
|
22
|
-
"""
|
11
|
+
"""A Converter that translates code into a set of PLANTUML diagrams."""
|
23
12
|
|
24
13
|
def __init__(
|
25
14
|
self,
|
@@ -30,110 +19,33 @@ class DiagramGenerator(Documenter):
|
|
30
19
|
"""Initialize the DiagramGenerator class
|
31
20
|
|
32
21
|
Arguments:
|
33
|
-
model: The LLM to use for translation. If an OpenAI model, the
|
34
|
-
`OPENAI_API_KEY` environment variable must be set and the
|
35
|
-
`OPENAI_ORG_ID` environment variable should be set if needed.
|
36
|
-
model_arguments: Additional arguments to pass to the LLM constructor.
|
37
|
-
source_language: The source programming language.
|
38
|
-
max_prompts: The maximum number of prompts to try before giving up.
|
39
|
-
db_path: path to chroma database
|
40
|
-
db_config: database configuraiton
|
41
22
|
diagram_type: type of PLANTUML diagram to generate
|
23
|
+
add_documentation: Whether to add a documentation step prior to
|
24
|
+
diagram generation.
|
42
25
|
"""
|
43
|
-
super().__init__(**kwargs)
|
44
26
|
self._diagram_type = diagram_type
|
45
27
|
self._add_documentation = add_documentation
|
46
|
-
self._documenter =
|
47
|
-
self._diagram_parser = UMLSyntaxParser(language="plantuml")
|
48
|
-
if add_documentation:
|
49
|
-
self._diagram_prompt_template_name = "diagram_with_documentation"
|
50
|
-
else:
|
51
|
-
self._diagram_prompt_template_name = "diagram"
|
52
|
-
self._load_diagram_prompt_engine()
|
28
|
+
self._documenter = Documenter(**kwargs)
|
53
29
|
|
54
|
-
|
55
|
-
|
56
|
-
|
30
|
+
super().__init__(**kwargs)
|
31
|
+
|
32
|
+
self.set_prompt("diagram_with_documentation" if add_documentation else "diagram")
|
33
|
+
self._parser = UMLSyntaxParser(language="plantuml")
|
57
34
|
|
58
|
-
|
59
|
-
n2 = round((self.max_prompts // n1) ** (1 / 2))
|
35
|
+
self._load_parameters()
|
60
36
|
|
61
|
-
|
62
|
-
|
37
|
+
def _load_prompt(self):
|
38
|
+
super()._load_prompt()
|
39
|
+
self._prompt = self._prompt.partial(DIAGRAM_TYPE=self._diagram_type)
|
63
40
|
|
41
|
+
def _input_runnable(self) -> Runnable:
|
64
42
|
if self._add_documentation:
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
**{
|
70
|
-
"SOURCE_CODE": input,
|
71
|
-
"DOCUMENTATION": documentation_text,
|
72
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
73
|
-
}
|
74
|
-
),
|
75
|
-
refiner=self._refiner,
|
76
|
-
max_retries=n1,
|
77
|
-
llm=self._llm,
|
78
|
-
)
|
79
|
-
else:
|
80
|
-
refine_output = RefinerParser(
|
81
|
-
parser=self._diagram_parser,
|
82
|
-
initial_prompt=self._diagram_prompt.format(
|
83
|
-
**{
|
84
|
-
"SOURCE_CODE": input,
|
85
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
86
|
-
}
|
87
|
-
),
|
88
|
-
refiner=self._refiner,
|
89
|
-
max_retries=n1,
|
90
|
-
llm=self._llm,
|
43
|
+
return RunnableParallel(
|
44
|
+
SOURCE_CODE=self._parser.parse_input,
|
45
|
+
DOCUMENTATION=self._documenter.chain,
|
46
|
+
context=self._retriever,
|
91
47
|
)
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
max_retries=n2,
|
96
|
-
)
|
97
|
-
completion_chain = self._prompt | self._llm
|
98
|
-
chain = RunnableParallel(
|
99
|
-
completion=completion_chain, prompt_value=self._diagram_prompt
|
100
|
-
) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
|
101
|
-
for _ in range(n3):
|
102
|
-
try:
|
103
|
-
if self._add_documentation:
|
104
|
-
return chain.invoke(
|
105
|
-
{
|
106
|
-
"SOURCE_CODE": input,
|
107
|
-
"DOCUMENTATION": documentation_text,
|
108
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
109
|
-
}
|
110
|
-
)
|
111
|
-
else:
|
112
|
-
return chain.invoke(
|
113
|
-
{
|
114
|
-
"SOURCE_CODE": input,
|
115
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
116
|
-
}
|
117
|
-
)
|
118
|
-
except OutputParserException:
|
119
|
-
pass
|
120
|
-
|
121
|
-
raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
|
122
|
-
|
123
|
-
@run_if_changed(
|
124
|
-
"_diagram_prompt_template_name",
|
125
|
-
"_source_language",
|
126
|
-
)
|
127
|
-
def _load_diagram_prompt_engine(self) -> None:
|
128
|
-
"""Load the prompt engine according to this instance's attributes.
|
129
|
-
|
130
|
-
If the relevant fields have not been changed since the last time this method was
|
131
|
-
called, nothing happens.
|
132
|
-
"""
|
133
|
-
self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
|
134
|
-
source_language=self._source_language,
|
135
|
-
target_language="text",
|
136
|
-
target_version=None,
|
137
|
-
prompt_template=self._diagram_prompt_template_name,
|
48
|
+
return RunnableParallel(
|
49
|
+
SOURCE_CODE=self._parser.parse_input,
|
50
|
+
context=self._retriever,
|
138
51
|
)
|
139
|
-
self._diagram_prompt = self._diagram_prompt_engine.prompt
|
janus/converter/translate.py
CHANGED
@@ -90,7 +90,7 @@ class Translator(Converter):
|
|
90
90
|
f"({self._source_language} != {self._target_language})"
|
91
91
|
)
|
92
92
|
|
93
|
-
prompt_engine = MODEL_PROMPT_ENGINES[self.
|
93
|
+
prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
|
94
94
|
source_language=self._source_language,
|
95
95
|
target_language=self._target_language,
|
96
96
|
target_version=self._target_version,
|
@@ -12,7 +12,7 @@ class TestAlcSplitter(unittest.TestCase):
|
|
12
12
|
def setUp(self):
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
|
-
llm
|
15
|
+
llm = load_model(model_name)
|
16
16
|
self.splitter = AlcSplitter(model=llm)
|
17
17
|
self.combiner = Combiner(language="ibmhlasm")
|
18
18
|
self.test_file = Path("janus/language/alc/_tests/alc.asm")
|
janus/language/alc/alc.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1
1
|
import re
|
2
2
|
from typing import Optional
|
3
3
|
|
4
|
-
from langchain.schema.language_model import BaseLanguageModel
|
5
|
-
|
6
4
|
from janus.language.block import CodeBlock
|
7
5
|
from janus.language.combine import Combiner
|
8
6
|
from janus.language.node import NodeType
|
9
7
|
from janus.language.treesitter import TreeSitterSplitter
|
8
|
+
from janus.llm.models_info import JanusModel
|
10
9
|
from janus.utils.logger import create_logger
|
11
10
|
|
12
11
|
log = create_logger(__name__)
|
@@ -27,7 +26,7 @@ class AlcSplitter(TreeSitterSplitter):
|
|
27
26
|
|
28
27
|
def __init__(
|
29
28
|
self,
|
30
|
-
model:
|
29
|
+
model: JanusModel | None = None,
|
31
30
|
max_tokens: int = 4096,
|
32
31
|
protected_node_types: tuple[str, ...] = (),
|
33
32
|
prune_node_types: tuple[str, ...] = (),
|
@@ -63,7 +62,7 @@ class AlcSplitter(TreeSitterSplitter):
|
|
63
62
|
# instruction and containing all the subsequent nodes up until the
|
64
63
|
# next csect or dsect instruction
|
65
64
|
sects: list[list[CodeBlock]] = [[]]
|
66
|
-
for c in block.children:
|
65
|
+
for c in sorted(block.children):
|
67
66
|
if c.node_type == "csect_instruction":
|
68
67
|
c.context_tags["alc_section"] = "CSECT"
|
69
68
|
sects.append([c])
|
@@ -101,7 +100,7 @@ class AlcListingSplitter(AlcSplitter):
|
|
101
100
|
|
102
101
|
def __init__(
|
103
102
|
self,
|
104
|
-
model:
|
103
|
+
model: JanusModel | None = None,
|
105
104
|
max_tokens: int = 4096,
|
106
105
|
protected_node_types: tuple[str, ...] = (),
|
107
106
|
prune_node_types: tuple[str, ...] = (),
|
@@ -129,12 +128,18 @@ class AlcListingSplitter(AlcSplitter):
|
|
129
128
|
prune_unprotected=prune_unprotected,
|
130
129
|
)
|
131
130
|
|
132
|
-
def
|
131
|
+
def split_string(self, code: str, name: str) -> CodeBlock:
|
132
|
+
# Override split_string to use processed code and track active usings
|
133
133
|
active_usings = self.get_active_usings(code)
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
134
|
+
processed_code = self.preproccess_assembly(code)
|
135
|
+
root = super().split_string(processed_code, name)
|
136
|
+
if active_usings is not None:
|
137
|
+
stack = [root]
|
138
|
+
while stack:
|
139
|
+
block = stack.pop()
|
140
|
+
block.context_tags["active_usings"] = active_usings
|
141
|
+
stack.extend(block.children)
|
142
|
+
return root
|
138
143
|
|
139
144
|
def preproccess_assembly(self, code: str) -> str:
|
140
145
|
"""Remove non-essential lines from an assembly snippet"""
|
@@ -142,7 +147,7 @@ class AlcListingSplitter(AlcSplitter):
|
|
142
147
|
lines = code.splitlines()
|
143
148
|
lines = self.strip_header_and_left(lines)
|
144
149
|
lines = self.strip_addresses(lines)
|
145
|
-
return "".join(str(line) for line in lines)
|
150
|
+
return "\n".join(str(line) for line in lines)
|
146
151
|
|
147
152
|
def get_active_usings(self, code: str) -> Optional[str]:
|
148
153
|
"""Look for 'active usings' in the ALC listing header"""
|
@@ -15,7 +15,7 @@ class TestBinarySplitter(unittest.TestCase):
|
|
15
15
|
def setUp(self):
|
16
16
|
model_name = "gpt-4o"
|
17
17
|
self.binary_file = Path("janus/language/binary/_tests/hello")
|
18
|
-
self.llm
|
18
|
+
self.llm = load_model(model_name)
|
19
19
|
self.splitter = BinarySplitter(model=self.llm)
|
20
20
|
os.environ["GHIDRA_INSTALL_PATH"] = "~/programs/ghidra_10.4_PUBLIC"
|
21
21
|
|
janus/language/binary/binary.py
CHANGED
@@ -5,11 +5,11 @@ import tempfile
|
|
5
5
|
from pathlib import Path
|
6
6
|
|
7
7
|
import tree_sitter
|
8
|
-
from langchain.schema.language_model import BaseLanguageModel
|
9
8
|
|
10
9
|
from janus.language.block import CodeBlock
|
11
10
|
from janus.language.combine import Combiner
|
12
11
|
from janus.language.treesitter import TreeSitterSplitter
|
12
|
+
from janus.llm.models_info import JanusModel
|
13
13
|
from janus.utils.enums import LANGUAGES
|
14
14
|
from janus.utils.logger import create_logger
|
15
15
|
|
@@ -31,7 +31,7 @@ class BinarySplitter(TreeSitterSplitter):
|
|
31
31
|
|
32
32
|
def __init__(
|
33
33
|
self,
|
34
|
-
model:
|
34
|
+
model: JanusModel | None = None,
|
35
35
|
max_tokens: int = 4096,
|
36
36
|
protected_node_types: tuple[str] = (),
|
37
37
|
prune_node_types: tuple[str] = (),
|
@@ -12,7 +12,7 @@ class TestMumpsSplitter(unittest.TestCase):
|
|
12
12
|
def setUp(self):
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
|
-
llm
|
15
|
+
llm = load_model(model_name)
|
16
16
|
self.splitter = MumpsSplitter(model=llm)
|
17
17
|
self.combiner = Combiner(language="mumps")
|
18
18
|
self.test_file = Path("janus/language/mumps/_tests/mumps.m")
|
janus/language/mumps/mumps.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1
1
|
import re
|
2
2
|
|
3
|
-
from langchain.schema.language_model import BaseLanguageModel
|
4
|
-
|
5
3
|
from janus.language.block import CodeBlock
|
6
4
|
from janus.language.combine import Combiner
|
7
5
|
from janus.language.node import NodeType
|
8
6
|
from janus.language.splitter import Splitter
|
7
|
+
from janus.llm.models_info import JanusModel
|
9
8
|
from janus.utils.logger import create_logger
|
10
9
|
|
11
10
|
log = create_logger(__name__)
|
@@ -44,7 +43,7 @@ class MumpsSplitter(Splitter):
|
|
44
43
|
|
45
44
|
def __init__(
|
46
45
|
self,
|
47
|
-
model:
|
46
|
+
model: JanusModel | None = None,
|
48
47
|
max_tokens: int = 4096,
|
49
48
|
protected_node_types: tuple[str] = ("routine_definition",),
|
50
49
|
prune_node_types: tuple[str] = (),
|
@@ -19,6 +19,7 @@ def get_flexible_ast(language: str, **kwargs) -> Splitter:
|
|
19
19
|
Returns:
|
20
20
|
A flexible AST splitter for the given language.
|
21
21
|
"""
|
22
|
+
kwargs.update(protected_node_types=())
|
22
23
|
if language == "ibmhlasm":
|
23
24
|
return AlcSplitter(**kwargs)
|
24
25
|
elif language == "mumps":
|
@@ -28,7 +29,7 @@ def get_flexible_ast(language: str, **kwargs) -> Splitter:
|
|
28
29
|
|
29
30
|
|
30
31
|
@register_splitter("ast-strict")
|
31
|
-
def get_strict_ast(language: str, **kwargs) -> Splitter:
|
32
|
+
def get_strict_ast(language: str, prune_unprotected=True, **kwargs) -> Splitter:
|
32
33
|
"""Get a strict AST splitter for the given language.
|
33
34
|
|
34
35
|
The strict splitter will only return nodes that are of a functional type.
|
@@ -41,7 +42,7 @@ def get_strict_ast(language: str, **kwargs) -> Splitter:
|
|
41
42
|
"""
|
42
43
|
kwargs.update(
|
43
44
|
protected_node_types=LANGUAGES[language]["functional_node_types"],
|
44
|
-
prune_unprotected=
|
45
|
+
prune_unprotected=prune_unprotected,
|
45
46
|
)
|
46
47
|
if language == "ibmhlasm":
|
47
48
|
return AlcSplitter(**kwargs)
|
janus/language/splitter.py
CHANGED
@@ -4,11 +4,11 @@ from pathlib import Path
|
|
4
4
|
from typing import List
|
5
5
|
|
6
6
|
import tiktoken
|
7
|
-
from langchain.schema.language_model import BaseLanguageModel
|
8
7
|
|
9
8
|
from janus.language.block import CodeBlock
|
10
9
|
from janus.language.file import FileManager
|
11
10
|
from janus.language.node import NodeType
|
11
|
+
from janus.llm.models_info import JanusModel
|
12
12
|
from janus.utils.logger import create_logger
|
13
13
|
|
14
14
|
log = create_logger(__name__)
|
@@ -44,7 +44,7 @@ class Splitter(FileManager):
|
|
44
44
|
def __init__(
|
45
45
|
self,
|
46
46
|
language: str,
|
47
|
-
model:
|
47
|
+
model: JanusModel | None = None,
|
48
48
|
max_tokens: int = 4096,
|
49
49
|
skip_merge: bool = False,
|
50
50
|
protected_node_types: tuple[str, ...] = (),
|
@@ -387,7 +387,10 @@ class Splitter(FileManager):
|
|
387
387
|
return
|
388
388
|
|
389
389
|
if self._is_protected(node):
|
390
|
-
|
390
|
+
log.error(
|
391
|
+
"Protected node too large for context!"
|
392
|
+
f" ({node.tokens} > {self.max_tokens})"
|
393
|
+
)
|
391
394
|
|
392
395
|
if node.children:
|
393
396
|
for child in node.children:
|
@@ -423,7 +426,7 @@ class Splitter(FileManager):
|
|
423
426
|
name = f"{node.name}-L#{node_line}"
|
424
427
|
tokens = self._count_tokens(line)
|
425
428
|
if tokens > self.max_tokens:
|
426
|
-
|
429
|
+
log.error(
|
427
430
|
"Irreducible node too large for context!"
|
428
431
|
f" ({tokens} > {self.max_tokens})"
|
429
432
|
)
|
@@ -13,7 +13,7 @@ class TestTreeSitterSplitter(unittest.TestCase):
|
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
15
|
self.maxDiff = None
|
16
|
-
self.llm
|
16
|
+
self.llm = load_model(model_name)
|
17
17
|
|
18
18
|
def _split(self):
|
19
19
|
"""Split the test file."""
|
@@ -7,10 +7,10 @@ from typing import Optional
|
|
7
7
|
|
8
8
|
import tree_sitter
|
9
9
|
from git import Repo
|
10
|
-
from langchain.schema.language_model import BaseLanguageModel
|
11
10
|
|
12
11
|
from janus.language.block import CodeBlock, NodeType
|
13
12
|
from janus.language.splitter import Splitter
|
13
|
+
from janus.llm.models_info import JanusModel
|
14
14
|
from janus.utils.enums import LANGUAGES
|
15
15
|
from janus.utils.logger import create_logger
|
16
16
|
|
@@ -25,7 +25,7 @@ class TreeSitterSplitter(Splitter):
|
|
25
25
|
def __init__(
|
26
26
|
self,
|
27
27
|
language: str,
|
28
|
-
model:
|
28
|
+
model: JanusModel | None = None,
|
29
29
|
max_tokens: int = 4096,
|
30
30
|
protected_node_types: tuple[str, ...] = (),
|
31
31
|
prune_node_types: tuple[str, ...] = (),
|
janus/llm/model_callbacks.py
CHANGED
@@ -12,6 +12,17 @@ from janus.utils.logger import create_logger
|
|
12
12
|
|
13
13
|
log = create_logger(__name__)
|
14
14
|
|
15
|
+
openai_model_reroutes = {
|
16
|
+
"gpt-4o": "gpt-4o-2024-05-13",
|
17
|
+
"gpt-4o-mini": "gpt-4o-mini",
|
18
|
+
"gpt-4": "gpt-4-0613",
|
19
|
+
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
20
|
+
"gpt-4-turbo-preview": "gpt-4-0125-preview",
|
21
|
+
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
|
22
|
+
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
|
23
|
+
"gpt-3.5-turbo-16k-0613": "gpt-3.5-turbo-0125",
|
24
|
+
}
|
25
|
+
|
15
26
|
|
16
27
|
# Updated 2024-06-21
|
17
28
|
COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
|
@@ -45,6 +56,8 @@ def _get_token_cost(
|
|
45
56
|
prompt_tokens: int, completion_tokens: int, model_id: str | None
|
46
57
|
) -> float:
|
47
58
|
"""Get the cost of tokens according to model ID"""
|
59
|
+
if model_id in openai_model_reroutes:
|
60
|
+
model_id = openai_model_reroutes[model_id]
|
48
61
|
if model_id not in COST_PER_1K_TOKENS:
|
49
62
|
raise ValueError(
|
50
63
|
f"Unknown model: {model_id}. Please provide a valid model name."
|