janus-llm 3.5.2__py3-none-any.whl → 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli.py +90 -42
- janus/converter/converter.py +111 -142
- janus/converter/diagram.py +21 -109
- janus/converter/translate.py +1 -1
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +16 -11
- janus/language/binary/_tests/test_binary.py +1 -1
- janus/language/binary/binary.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +1 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/naive/simple_ast.py +3 -2
- janus/language/splitter.py +7 -4
- janus/language/treesitter/_tests/test_treesitter.py +1 -1
- janus/language/treesitter/treesitter.py +2 -2
- janus/llm/model_callbacks.py +13 -0
- janus/llm/models_info.py +118 -71
- janus/metrics/metric.py +15 -14
- janus/parsers/uml.py +60 -23
- janus/refiners/refiner.py +106 -64
- janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/METADATA +1 -1
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/RECORD +26 -26
- janus/parsers/refiner_parser.py +0 -46
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/entry_points.txt +0 -0
janus/converter/diagram.py
CHANGED
@@ -1,14 +1,6 @@
|
|
1
|
-
import
|
1
|
+
from langchain_core.runnables import Runnable, RunnableParallel
|
2
2
|
|
3
|
-
from langchain.output_parsers import RetryWithErrorOutputParser
|
4
|
-
from langchain_core.exceptions import OutputParserException
|
5
|
-
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
6
|
-
|
7
|
-
from janus.converter.converter import run_if_changed
|
8
3
|
from janus.converter.document import Documenter
|
9
|
-
from janus.language.block import TranslatedCodeBlock
|
10
|
-
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
11
|
-
from janus.parsers.refiner_parser import RefinerParser
|
12
4
|
from janus.parsers.uml import UMLSyntaxParser
|
13
5
|
from janus.utils.logger import create_logger
|
14
6
|
|
@@ -16,10 +8,7 @@ log = create_logger(__name__)
|
|
16
8
|
|
17
9
|
|
18
10
|
class DiagramGenerator(Documenter):
|
19
|
-
"""
|
20
|
-
|
21
|
-
A class that translates code from one programming language to a set of diagrams.
|
22
|
-
"""
|
11
|
+
"""A Converter that translates code into a set of PLANTUML diagrams."""
|
23
12
|
|
24
13
|
def __init__(
|
25
14
|
self,
|
@@ -30,110 +19,33 @@ class DiagramGenerator(Documenter):
|
|
30
19
|
"""Initialize the DiagramGenerator class
|
31
20
|
|
32
21
|
Arguments:
|
33
|
-
model: The LLM to use for translation. If an OpenAI model, the
|
34
|
-
`OPENAI_API_KEY` environment variable must be set and the
|
35
|
-
`OPENAI_ORG_ID` environment variable should be set if needed.
|
36
|
-
model_arguments: Additional arguments to pass to the LLM constructor.
|
37
|
-
source_language: The source programming language.
|
38
|
-
max_prompts: The maximum number of prompts to try before giving up.
|
39
|
-
db_path: path to chroma database
|
40
|
-
db_config: database configuraiton
|
41
22
|
diagram_type: type of PLANTUML diagram to generate
|
23
|
+
add_documentation: Whether to add a documentation step prior to
|
24
|
+
diagram generation.
|
42
25
|
"""
|
43
|
-
super().__init__(**kwargs)
|
44
26
|
self._diagram_type = diagram_type
|
45
27
|
self._add_documentation = add_documentation
|
46
|
-
self._documenter =
|
47
|
-
self._diagram_parser = UMLSyntaxParser(language="plantuml")
|
48
|
-
if add_documentation:
|
49
|
-
self._diagram_prompt_template_name = "diagram_with_documentation"
|
50
|
-
else:
|
51
|
-
self._diagram_prompt_template_name = "diagram"
|
52
|
-
self._load_diagram_prompt_engine()
|
28
|
+
self._documenter = Documenter(**kwargs)
|
53
29
|
|
54
|
-
|
55
|
-
|
56
|
-
|
30
|
+
super().__init__(**kwargs)
|
31
|
+
|
32
|
+
self.set_prompt("diagram_with_documentation" if add_documentation else "diagram")
|
33
|
+
self._parser = UMLSyntaxParser(language="plantuml")
|
57
34
|
|
58
|
-
|
59
|
-
n2 = round((self.max_prompts // n1) ** (1 / 2))
|
35
|
+
self._load_parameters()
|
60
36
|
|
61
|
-
|
62
|
-
|
37
|
+
def _load_prompt(self):
|
38
|
+
super()._load_prompt()
|
39
|
+
self._prompt = self._prompt.partial(DIAGRAM_TYPE=self._diagram_type)
|
63
40
|
|
41
|
+
def _input_runnable(self) -> Runnable:
|
64
42
|
if self._add_documentation:
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
**{
|
70
|
-
"SOURCE_CODE": input,
|
71
|
-
"DOCUMENTATION": documentation_text,
|
72
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
73
|
-
}
|
74
|
-
),
|
75
|
-
refiner=self._refiner,
|
76
|
-
max_retries=n1,
|
77
|
-
llm=self._llm,
|
78
|
-
)
|
79
|
-
else:
|
80
|
-
refine_output = RefinerParser(
|
81
|
-
parser=self._diagram_parser,
|
82
|
-
initial_prompt=self._diagram_prompt.format(
|
83
|
-
**{
|
84
|
-
"SOURCE_CODE": input,
|
85
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
86
|
-
}
|
87
|
-
),
|
88
|
-
refiner=self._refiner,
|
89
|
-
max_retries=n1,
|
90
|
-
llm=self._llm,
|
43
|
+
return RunnableParallel(
|
44
|
+
SOURCE_CODE=self._parser.parse_input,
|
45
|
+
DOCUMENTATION=self._documenter.chain,
|
46
|
+
context=self._retriever,
|
91
47
|
)
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
max_retries=n2,
|
96
|
-
)
|
97
|
-
completion_chain = self._prompt | self._llm
|
98
|
-
chain = RunnableParallel(
|
99
|
-
completion=completion_chain, prompt_value=self._diagram_prompt
|
100
|
-
) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
|
101
|
-
for _ in range(n3):
|
102
|
-
try:
|
103
|
-
if self._add_documentation:
|
104
|
-
return chain.invoke(
|
105
|
-
{
|
106
|
-
"SOURCE_CODE": input,
|
107
|
-
"DOCUMENTATION": documentation_text,
|
108
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
109
|
-
}
|
110
|
-
)
|
111
|
-
else:
|
112
|
-
return chain.invoke(
|
113
|
-
{
|
114
|
-
"SOURCE_CODE": input,
|
115
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
116
|
-
}
|
117
|
-
)
|
118
|
-
except OutputParserException:
|
119
|
-
pass
|
120
|
-
|
121
|
-
raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
|
122
|
-
|
123
|
-
@run_if_changed(
|
124
|
-
"_diagram_prompt_template_name",
|
125
|
-
"_source_language",
|
126
|
-
)
|
127
|
-
def _load_diagram_prompt_engine(self) -> None:
|
128
|
-
"""Load the prompt engine according to this instance's attributes.
|
129
|
-
|
130
|
-
If the relevant fields have not been changed since the last time this method was
|
131
|
-
called, nothing happens.
|
132
|
-
"""
|
133
|
-
self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
|
134
|
-
source_language=self._source_language,
|
135
|
-
target_language="text",
|
136
|
-
target_version=None,
|
137
|
-
prompt_template=self._diagram_prompt_template_name,
|
48
|
+
return RunnableParallel(
|
49
|
+
SOURCE_CODE=self._parser.parse_input,
|
50
|
+
context=self._retriever,
|
138
51
|
)
|
139
|
-
self._diagram_prompt = self._diagram_prompt_engine.prompt
|
janus/converter/translate.py
CHANGED
@@ -90,7 +90,7 @@ class Translator(Converter):
|
|
90
90
|
f"({self._source_language} != {self._target_language})"
|
91
91
|
)
|
92
92
|
|
93
|
-
prompt_engine = MODEL_PROMPT_ENGINES[self.
|
93
|
+
prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
|
94
94
|
source_language=self._source_language,
|
95
95
|
target_language=self._target_language,
|
96
96
|
target_version=self._target_version,
|
@@ -12,7 +12,7 @@ class TestAlcSplitter(unittest.TestCase):
|
|
12
12
|
def setUp(self):
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
|
-
llm
|
15
|
+
llm = load_model(model_name)
|
16
16
|
self.splitter = AlcSplitter(model=llm)
|
17
17
|
self.combiner = Combiner(language="ibmhlasm")
|
18
18
|
self.test_file = Path("janus/language/alc/_tests/alc.asm")
|
janus/language/alc/alc.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1
1
|
import re
|
2
2
|
from typing import Optional
|
3
3
|
|
4
|
-
from langchain.schema.language_model import BaseLanguageModel
|
5
|
-
|
6
4
|
from janus.language.block import CodeBlock
|
7
5
|
from janus.language.combine import Combiner
|
8
6
|
from janus.language.node import NodeType
|
9
7
|
from janus.language.treesitter import TreeSitterSplitter
|
8
|
+
from janus.llm.models_info import JanusModel
|
10
9
|
from janus.utils.logger import create_logger
|
11
10
|
|
12
11
|
log = create_logger(__name__)
|
@@ -27,7 +26,7 @@ class AlcSplitter(TreeSitterSplitter):
|
|
27
26
|
|
28
27
|
def __init__(
|
29
28
|
self,
|
30
|
-
model:
|
29
|
+
model: JanusModel | None = None,
|
31
30
|
max_tokens: int = 4096,
|
32
31
|
protected_node_types: tuple[str, ...] = (),
|
33
32
|
prune_node_types: tuple[str, ...] = (),
|
@@ -63,7 +62,7 @@ class AlcSplitter(TreeSitterSplitter):
|
|
63
62
|
# instruction and containing all the subsequent nodes up until the
|
64
63
|
# next csect or dsect instruction
|
65
64
|
sects: list[list[CodeBlock]] = [[]]
|
66
|
-
for c in block.children:
|
65
|
+
for c in sorted(block.children):
|
67
66
|
if c.node_type == "csect_instruction":
|
68
67
|
c.context_tags["alc_section"] = "CSECT"
|
69
68
|
sects.append([c])
|
@@ -101,7 +100,7 @@ class AlcListingSplitter(AlcSplitter):
|
|
101
100
|
|
102
101
|
def __init__(
|
103
102
|
self,
|
104
|
-
model:
|
103
|
+
model: JanusModel | None = None,
|
105
104
|
max_tokens: int = 4096,
|
106
105
|
protected_node_types: tuple[str, ...] = (),
|
107
106
|
prune_node_types: tuple[str, ...] = (),
|
@@ -129,12 +128,18 @@ class AlcListingSplitter(AlcSplitter):
|
|
129
128
|
prune_unprotected=prune_unprotected,
|
130
129
|
)
|
131
130
|
|
132
|
-
def
|
131
|
+
def split_string(self, code: str, name: str) -> CodeBlock:
|
132
|
+
# Override split_string to use processed code and track active usings
|
133
133
|
active_usings = self.get_active_usings(code)
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
134
|
+
processed_code = self.preproccess_assembly(code)
|
135
|
+
root = super().split_string(processed_code, name)
|
136
|
+
if active_usings is not None:
|
137
|
+
stack = [root]
|
138
|
+
while stack:
|
139
|
+
block = stack.pop()
|
140
|
+
block.context_tags["active_usings"] = active_usings
|
141
|
+
stack.extend(block.children)
|
142
|
+
return root
|
138
143
|
|
139
144
|
def preproccess_assembly(self, code: str) -> str:
|
140
145
|
"""Remove non-essential lines from an assembly snippet"""
|
@@ -142,7 +147,7 @@ class AlcListingSplitter(AlcSplitter):
|
|
142
147
|
lines = code.splitlines()
|
143
148
|
lines = self.strip_header_and_left(lines)
|
144
149
|
lines = self.strip_addresses(lines)
|
145
|
-
return "".join(str(line) for line in lines)
|
150
|
+
return "\n".join(str(line) for line in lines)
|
146
151
|
|
147
152
|
def get_active_usings(self, code: str) -> Optional[str]:
|
148
153
|
"""Look for 'active usings' in the ALC listing header"""
|
@@ -15,7 +15,7 @@ class TestBinarySplitter(unittest.TestCase):
|
|
15
15
|
def setUp(self):
|
16
16
|
model_name = "gpt-4o"
|
17
17
|
self.binary_file = Path("janus/language/binary/_tests/hello")
|
18
|
-
self.llm
|
18
|
+
self.llm = load_model(model_name)
|
19
19
|
self.splitter = BinarySplitter(model=self.llm)
|
20
20
|
os.environ["GHIDRA_INSTALL_PATH"] = "~/programs/ghidra_10.4_PUBLIC"
|
21
21
|
|
janus/language/binary/binary.py
CHANGED
@@ -5,11 +5,11 @@ import tempfile
|
|
5
5
|
from pathlib import Path
|
6
6
|
|
7
7
|
import tree_sitter
|
8
|
-
from langchain.schema.language_model import BaseLanguageModel
|
9
8
|
|
10
9
|
from janus.language.block import CodeBlock
|
11
10
|
from janus.language.combine import Combiner
|
12
11
|
from janus.language.treesitter import TreeSitterSplitter
|
12
|
+
from janus.llm.models_info import JanusModel
|
13
13
|
from janus.utils.enums import LANGUAGES
|
14
14
|
from janus.utils.logger import create_logger
|
15
15
|
|
@@ -31,7 +31,7 @@ class BinarySplitter(TreeSitterSplitter):
|
|
31
31
|
|
32
32
|
def __init__(
|
33
33
|
self,
|
34
|
-
model:
|
34
|
+
model: JanusModel | None = None,
|
35
35
|
max_tokens: int = 4096,
|
36
36
|
protected_node_types: tuple[str] = (),
|
37
37
|
prune_node_types: tuple[str] = (),
|
@@ -12,7 +12,7 @@ class TestMumpsSplitter(unittest.TestCase):
|
|
12
12
|
def setUp(self):
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
|
-
llm
|
15
|
+
llm = load_model(model_name)
|
16
16
|
self.splitter = MumpsSplitter(model=llm)
|
17
17
|
self.combiner = Combiner(language="mumps")
|
18
18
|
self.test_file = Path("janus/language/mumps/_tests/mumps.m")
|
janus/language/mumps/mumps.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1
1
|
import re
|
2
2
|
|
3
|
-
from langchain.schema.language_model import BaseLanguageModel
|
4
|
-
|
5
3
|
from janus.language.block import CodeBlock
|
6
4
|
from janus.language.combine import Combiner
|
7
5
|
from janus.language.node import NodeType
|
8
6
|
from janus.language.splitter import Splitter
|
7
|
+
from janus.llm.models_info import JanusModel
|
9
8
|
from janus.utils.logger import create_logger
|
10
9
|
|
11
10
|
log = create_logger(__name__)
|
@@ -44,7 +43,7 @@ class MumpsSplitter(Splitter):
|
|
44
43
|
|
45
44
|
def __init__(
|
46
45
|
self,
|
47
|
-
model:
|
46
|
+
model: JanusModel | None = None,
|
48
47
|
max_tokens: int = 4096,
|
49
48
|
protected_node_types: tuple[str] = ("routine_definition",),
|
50
49
|
prune_node_types: tuple[str] = (),
|
@@ -19,6 +19,7 @@ def get_flexible_ast(language: str, **kwargs) -> Splitter:
|
|
19
19
|
Returns:
|
20
20
|
A flexible AST splitter for the given language.
|
21
21
|
"""
|
22
|
+
kwargs.update(protected_node_types=())
|
22
23
|
if language == "ibmhlasm":
|
23
24
|
return AlcSplitter(**kwargs)
|
24
25
|
elif language == "mumps":
|
@@ -28,7 +29,7 @@ def get_flexible_ast(language: str, **kwargs) -> Splitter:
|
|
28
29
|
|
29
30
|
|
30
31
|
@register_splitter("ast-strict")
|
31
|
-
def get_strict_ast(language: str, **kwargs) -> Splitter:
|
32
|
+
def get_strict_ast(language: str, prune_unprotected=True, **kwargs) -> Splitter:
|
32
33
|
"""Get a strict AST splitter for the given language.
|
33
34
|
|
34
35
|
The strict splitter will only return nodes that are of a functional type.
|
@@ -41,7 +42,7 @@ def get_strict_ast(language: str, **kwargs) -> Splitter:
|
|
41
42
|
"""
|
42
43
|
kwargs.update(
|
43
44
|
protected_node_types=LANGUAGES[language]["functional_node_types"],
|
44
|
-
prune_unprotected=
|
45
|
+
prune_unprotected=prune_unprotected,
|
45
46
|
)
|
46
47
|
if language == "ibmhlasm":
|
47
48
|
return AlcSplitter(**kwargs)
|
janus/language/splitter.py
CHANGED
@@ -4,11 +4,11 @@ from pathlib import Path
|
|
4
4
|
from typing import List
|
5
5
|
|
6
6
|
import tiktoken
|
7
|
-
from langchain.schema.language_model import BaseLanguageModel
|
8
7
|
|
9
8
|
from janus.language.block import CodeBlock
|
10
9
|
from janus.language.file import FileManager
|
11
10
|
from janus.language.node import NodeType
|
11
|
+
from janus.llm.models_info import JanusModel
|
12
12
|
from janus.utils.logger import create_logger
|
13
13
|
|
14
14
|
log = create_logger(__name__)
|
@@ -44,7 +44,7 @@ class Splitter(FileManager):
|
|
44
44
|
def __init__(
|
45
45
|
self,
|
46
46
|
language: str,
|
47
|
-
model:
|
47
|
+
model: JanusModel | None = None,
|
48
48
|
max_tokens: int = 4096,
|
49
49
|
skip_merge: bool = False,
|
50
50
|
protected_node_types: tuple[str, ...] = (),
|
@@ -387,7 +387,10 @@ class Splitter(FileManager):
|
|
387
387
|
return
|
388
388
|
|
389
389
|
if self._is_protected(node):
|
390
|
-
|
390
|
+
log.error(
|
391
|
+
"Protected node too large for context!"
|
392
|
+
f" ({node.tokens} > {self.max_tokens})"
|
393
|
+
)
|
391
394
|
|
392
395
|
if node.children:
|
393
396
|
for child in node.children:
|
@@ -423,7 +426,7 @@ class Splitter(FileManager):
|
|
423
426
|
name = f"{node.name}-L#{node_line}"
|
424
427
|
tokens = self._count_tokens(line)
|
425
428
|
if tokens > self.max_tokens:
|
426
|
-
|
429
|
+
log.error(
|
427
430
|
"Irreducible node too large for context!"
|
428
431
|
f" ({tokens} > {self.max_tokens})"
|
429
432
|
)
|
@@ -13,7 +13,7 @@ class TestTreeSitterSplitter(unittest.TestCase):
|
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
15
|
self.maxDiff = None
|
16
|
-
self.llm
|
16
|
+
self.llm = load_model(model_name)
|
17
17
|
|
18
18
|
def _split(self):
|
19
19
|
"""Split the test file."""
|
@@ -7,10 +7,10 @@ from typing import Optional
|
|
7
7
|
|
8
8
|
import tree_sitter
|
9
9
|
from git import Repo
|
10
|
-
from langchain.schema.language_model import BaseLanguageModel
|
11
10
|
|
12
11
|
from janus.language.block import CodeBlock, NodeType
|
13
12
|
from janus.language.splitter import Splitter
|
13
|
+
from janus.llm.models_info import JanusModel
|
14
14
|
from janus.utils.enums import LANGUAGES
|
15
15
|
from janus.utils.logger import create_logger
|
16
16
|
|
@@ -25,7 +25,7 @@ class TreeSitterSplitter(Splitter):
|
|
25
25
|
def __init__(
|
26
26
|
self,
|
27
27
|
language: str,
|
28
|
-
model:
|
28
|
+
model: JanusModel | None = None,
|
29
29
|
max_tokens: int = 4096,
|
30
30
|
protected_node_types: tuple[str, ...] = (),
|
31
31
|
prune_node_types: tuple[str, ...] = (),
|
janus/llm/model_callbacks.py
CHANGED
@@ -12,6 +12,17 @@ from janus.utils.logger import create_logger
|
|
12
12
|
|
13
13
|
log = create_logger(__name__)
|
14
14
|
|
15
|
+
openai_model_reroutes = {
|
16
|
+
"gpt-4o": "gpt-4o-2024-05-13",
|
17
|
+
"gpt-4o-mini": "gpt-4o-mini",
|
18
|
+
"gpt-4": "gpt-4-0613",
|
19
|
+
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
20
|
+
"gpt-4-turbo-preview": "gpt-4-0125-preview",
|
21
|
+
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
|
22
|
+
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
|
23
|
+
"gpt-3.5-turbo-16k-0613": "gpt-3.5-turbo-0125",
|
24
|
+
}
|
25
|
+
|
15
26
|
|
16
27
|
# Updated 2024-06-21
|
17
28
|
COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
|
@@ -45,6 +56,8 @@ def _get_token_cost(
|
|
45
56
|
prompt_tokens: int, completion_tokens: int, model_id: str | None
|
46
57
|
) -> float:
|
47
58
|
"""Get the cost of tokens according to model ID"""
|
59
|
+
if model_id in openai_model_reroutes:
|
60
|
+
model_id = openai_model_reroutes[model_id]
|
48
61
|
if model_id not in COST_PER_1K_TOKENS:
|
49
62
|
raise ValueError(
|
50
63
|
f"Unknown model: {model_id}. Please provide a valid model name."
|