janus-llm 3.5.3__py3-none-any.whl → 4.1.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,14 +1,6 @@
1
- import math
1
+ from langchain_core.runnables import Runnable, RunnableParallel
2
2
 
3
- from langchain.output_parsers import RetryWithErrorOutputParser
4
- from langchain_core.exceptions import OutputParserException
5
- from langchain_core.runnables import RunnableLambda, RunnableParallel
6
-
7
- from janus.converter.converter import run_if_changed
8
3
  from janus.converter.document import Documenter
9
- from janus.language.block import TranslatedCodeBlock
10
- from janus.llm.models_info import MODEL_PROMPT_ENGINES
11
- from janus.parsers.refiner_parser import RefinerParser
12
4
  from janus.parsers.uml import UMLSyntaxParser
13
5
  from janus.utils.logger import create_logger
14
6
 
@@ -16,10 +8,7 @@ log = create_logger(__name__)
16
8
 
17
9
 
18
10
  class DiagramGenerator(Documenter):
19
- """DiagramGenerator
20
-
21
- A class that translates code from one programming language to a set of diagrams.
22
- """
11
+ """A Converter that translates code into a set of PLANTUML diagrams."""
23
12
 
24
13
  def __init__(
25
14
  self,
@@ -30,110 +19,33 @@ class DiagramGenerator(Documenter):
30
19
  """Initialize the DiagramGenerator class
31
20
 
32
21
  Arguments:
33
- model: The LLM to use for translation. If an OpenAI model, the
34
- `OPENAI_API_KEY` environment variable must be set and the
35
- `OPENAI_ORG_ID` environment variable should be set if needed.
36
- model_arguments: Additional arguments to pass to the LLM constructor.
37
- source_language: The source programming language.
38
- max_prompts: The maximum number of prompts to try before giving up.
39
- db_path: path to chroma database
40
- db_config: database configuraiton
41
22
  diagram_type: type of PLANTUML diagram to generate
23
+ add_documentation: Whether to add a documentation step prior to
24
+ diagram generation.
42
25
  """
43
- super().__init__(**kwargs)
44
26
  self._diagram_type = diagram_type
45
27
  self._add_documentation = add_documentation
46
- self._documenter = None
47
- self._diagram_parser = UMLSyntaxParser(language="plantuml")
48
- if add_documentation:
49
- self._diagram_prompt_template_name = "diagram_with_documentation"
50
- else:
51
- self._diagram_prompt_template_name = "diagram"
52
- self._load_diagram_prompt_engine()
28
+ self._documenter = Documenter(**kwargs)
53
29
 
54
- def _run_chain(self, block: TranslatedCodeBlock) -> str:
55
- input = self._parser.parse_input(block.original)
56
- n1 = round(self.max_prompts ** (1 / 3))
30
+ super().__init__(**kwargs)
31
+
32
+ self.set_prompt("diagram_with_documentation" if add_documentation else "diagram")
33
+ self._parser = UMLSyntaxParser(language="plantuml")
57
34
 
58
- # Retries with the input, output, and error
59
- n2 = round((self.max_prompts // n1) ** (1 / 2))
35
+ self._load_parameters()
60
36
 
61
- # Retries with just the input
62
- n3 = math.ceil(self.max_prompts / (n1 * n2))
37
+ def _load_prompt(self):
38
+ super()._load_prompt()
39
+ self._prompt = self._prompt.partial(DIAGRAM_TYPE=self._diagram_type)
63
40
 
41
+ def _input_runnable(self) -> Runnable:
64
42
  if self._add_documentation:
65
- documentation_text = super()._run_chain(block)
66
- refine_output = RefinerParser(
67
- parser=self._diagram_parser,
68
- initial_prompt=self._diagram_prompt.format(
69
- **{
70
- "SOURCE_CODE": input,
71
- "DOCUMENTATION": documentation_text,
72
- "DIAGRAM_TYPE": self._diagram_type,
73
- }
74
- ),
75
- refiner=self._refiner,
76
- max_retries=n1,
77
- llm=self._llm,
78
- )
79
- else:
80
- refine_output = RefinerParser(
81
- parser=self._diagram_parser,
82
- initial_prompt=self._diagram_prompt.format(
83
- **{
84
- "SOURCE_CODE": input,
85
- "DIAGRAM_TYPE": self._diagram_type,
86
- }
87
- ),
88
- refiner=self._refiner,
89
- max_retries=n1,
90
- llm=self._llm,
43
+ return RunnableParallel(
44
+ SOURCE_CODE=self._parser.parse_input,
45
+ DOCUMENTATION=self._documenter.chain,
46
+ context=self._retriever,
91
47
  )
92
- retry = RetryWithErrorOutputParser.from_llm(
93
- llm=self._llm,
94
- parser=refine_output,
95
- max_retries=n2,
96
- )
97
- completion_chain = self._prompt | self._llm
98
- chain = RunnableParallel(
99
- completion=completion_chain, prompt_value=self._diagram_prompt
100
- ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
101
- for _ in range(n3):
102
- try:
103
- if self._add_documentation:
104
- return chain.invoke(
105
- {
106
- "SOURCE_CODE": input,
107
- "DOCUMENTATION": documentation_text,
108
- "DIAGRAM_TYPE": self._diagram_type,
109
- }
110
- )
111
- else:
112
- return chain.invoke(
113
- {
114
- "SOURCE_CODE": input,
115
- "DIAGRAM_TYPE": self._diagram_type,
116
- }
117
- )
118
- except OutputParserException:
119
- pass
120
-
121
- raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
122
-
123
- @run_if_changed(
124
- "_diagram_prompt_template_name",
125
- "_source_language",
126
- )
127
- def _load_diagram_prompt_engine(self) -> None:
128
- """Load the prompt engine according to this instance's attributes.
129
-
130
- If the relevant fields have not been changed since the last time this method was
131
- called, nothing happens.
132
- """
133
- self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
134
- source_language=self._source_language,
135
- target_language="text",
136
- target_version=None,
137
- prompt_template=self._diagram_prompt_template_name,
48
+ return RunnableParallel(
49
+ SOURCE_CODE=self._parser.parse_input,
50
+ context=self._retriever,
138
51
  )
139
- self._diagram_prompt = self._diagram_prompt_engine.prompt
@@ -90,7 +90,7 @@ class Translator(Converter):
90
90
  f"({self._source_language} != {self._target_language})"
91
91
  )
92
92
 
93
- prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
93
+ prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
94
94
  source_language=self._source_language,
95
95
  target_language=self._target_language,
96
96
  target_version=self._target_version,
@@ -12,7 +12,7 @@ class TestAlcSplitter(unittest.TestCase):
12
12
  def setUp(self):
13
13
  """Set up the tests."""
14
14
  model_name = "gpt-4o"
15
- llm, _, _, _ = load_model(model_name)
15
+ llm = load_model(model_name)
16
16
  self.splitter = AlcSplitter(model=llm)
17
17
  self.combiner = Combiner(language="ibmhlasm")
18
18
  self.test_file = Path("janus/language/alc/_tests/alc.asm")
janus/language/alc/alc.py CHANGED
@@ -1,12 +1,11 @@
1
1
  import re
2
2
  from typing import Optional
3
3
 
4
- from langchain.schema.language_model import BaseLanguageModel
5
-
6
4
  from janus.language.block import CodeBlock
7
5
  from janus.language.combine import Combiner
8
6
  from janus.language.node import NodeType
9
7
  from janus.language.treesitter import TreeSitterSplitter
8
+ from janus.llm.models_info import JanusModel
10
9
  from janus.utils.logger import create_logger
11
10
 
12
11
  log = create_logger(__name__)
@@ -27,7 +26,7 @@ class AlcSplitter(TreeSitterSplitter):
27
26
 
28
27
  def __init__(
29
28
  self,
30
- model: None | BaseLanguageModel = None,
29
+ model: JanusModel | None = None,
31
30
  max_tokens: int = 4096,
32
31
  protected_node_types: tuple[str, ...] = (),
33
32
  prune_node_types: tuple[str, ...] = (),
@@ -101,7 +100,7 @@ class AlcListingSplitter(AlcSplitter):
101
100
 
102
101
  def __init__(
103
102
  self,
104
- model: None | BaseLanguageModel = None,
103
+ model: JanusModel | None = None,
105
104
  max_tokens: int = 4096,
106
105
  protected_node_types: tuple[str, ...] = (),
107
106
  prune_node_types: tuple[str, ...] = (),
@@ -129,12 +128,18 @@ class AlcListingSplitter(AlcSplitter):
129
128
  prune_unprotected=prune_unprotected,
130
129
  )
131
130
 
132
- def _get_ast(self, code: str) -> CodeBlock:
131
+ def split_string(self, code: str, name: str) -> CodeBlock:
132
+ # Override split_string to use processed code and track active usings
133
133
  active_usings = self.get_active_usings(code)
134
- code = self.preproccess_assembly(code)
135
- ast: CodeBlock = super()._get_ast(code)
136
- ast.context_tags["active_usings"] = active_usings
137
- return ast
134
+ processed_code = self.preproccess_assembly(code)
135
+ root = super().split_string(processed_code, name)
136
+ if active_usings is not None:
137
+ stack = [root]
138
+ while stack:
139
+ block = stack.pop()
140
+ block.context_tags["active_usings"] = active_usings
141
+ stack.extend(block.children)
142
+ return root
138
143
 
139
144
  def preproccess_assembly(self, code: str) -> str:
140
145
  """Remove non-essential lines from an assembly snippet"""
@@ -142,7 +147,7 @@ class AlcListingSplitter(AlcSplitter):
142
147
  lines = code.splitlines()
143
148
  lines = self.strip_header_and_left(lines)
144
149
  lines = self.strip_addresses(lines)
145
- return "".join(str(line) for line in lines)
150
+ return "\n".join(str(line) for line in lines)
146
151
 
147
152
  def get_active_usings(self, code: str) -> Optional[str]:
148
153
  """Look for 'active usings' in the ALC listing header"""
@@ -15,7 +15,7 @@ class TestBinarySplitter(unittest.TestCase):
15
15
  def setUp(self):
16
16
  model_name = "gpt-4o"
17
17
  self.binary_file = Path("janus/language/binary/_tests/hello")
18
- self.llm, _, _, _ = load_model(model_name)
18
+ self.llm = load_model(model_name)
19
19
  self.splitter = BinarySplitter(model=self.llm)
20
20
  os.environ["GHIDRA_INSTALL_PATH"] = "~/programs/ghidra_10.4_PUBLIC"
21
21
 
@@ -5,11 +5,11 @@ import tempfile
5
5
  from pathlib import Path
6
6
 
7
7
  import tree_sitter
8
- from langchain.schema.language_model import BaseLanguageModel
9
8
 
10
9
  from janus.language.block import CodeBlock
11
10
  from janus.language.combine import Combiner
12
11
  from janus.language.treesitter import TreeSitterSplitter
12
+ from janus.llm.models_info import JanusModel
13
13
  from janus.utils.enums import LANGUAGES
14
14
  from janus.utils.logger import create_logger
15
15
 
@@ -31,7 +31,7 @@ class BinarySplitter(TreeSitterSplitter):
31
31
 
32
32
  def __init__(
33
33
  self,
34
- model: None | BaseLanguageModel = None,
34
+ model: JanusModel | None = None,
35
35
  max_tokens: int = 4096,
36
36
  protected_node_types: tuple[str] = (),
37
37
  prune_node_types: tuple[str] = (),
@@ -12,7 +12,7 @@ class TestMumpsSplitter(unittest.TestCase):
12
12
  def setUp(self):
13
13
  """Set up the tests."""
14
14
  model_name = "gpt-4o"
15
- llm, _, _, _ = load_model(model_name)
15
+ llm = load_model(model_name)
16
16
  self.splitter = MumpsSplitter(model=llm)
17
17
  self.combiner = Combiner(language="mumps")
18
18
  self.test_file = Path("janus/language/mumps/_tests/mumps.m")
@@ -1,11 +1,10 @@
1
1
  import re
2
2
 
3
- from langchain.schema.language_model import BaseLanguageModel
4
-
5
3
  from janus.language.block import CodeBlock
6
4
  from janus.language.combine import Combiner
7
5
  from janus.language.node import NodeType
8
6
  from janus.language.splitter import Splitter
7
+ from janus.llm.models_info import JanusModel
9
8
  from janus.utils.logger import create_logger
10
9
 
11
10
  log = create_logger(__name__)
@@ -44,7 +43,7 @@ class MumpsSplitter(Splitter):
44
43
 
45
44
  def __init__(
46
45
  self,
47
- model: None | BaseLanguageModel = None,
46
+ model: JanusModel | None = None,
48
47
  max_tokens: int = 4096,
49
48
  protected_node_types: tuple[str] = ("routine_definition",),
50
49
  prune_node_types: tuple[str] = (),
@@ -4,11 +4,11 @@ from pathlib import Path
4
4
  from typing import List
5
5
 
6
6
  import tiktoken
7
- from langchain.schema.language_model import BaseLanguageModel
8
7
 
9
8
  from janus.language.block import CodeBlock
10
9
  from janus.language.file import FileManager
11
10
  from janus.language.node import NodeType
11
+ from janus.llm.models_info import JanusModel
12
12
  from janus.utils.logger import create_logger
13
13
 
14
14
  log = create_logger(__name__)
@@ -44,7 +44,7 @@ class Splitter(FileManager):
44
44
  def __init__(
45
45
  self,
46
46
  language: str,
47
- model: None | BaseLanguageModel = None,
47
+ model: JanusModel | None = None,
48
48
  max_tokens: int = 4096,
49
49
  skip_merge: bool = False,
50
50
  protected_node_types: tuple[str, ...] = (),
@@ -13,7 +13,7 @@ class TestTreeSitterSplitter(unittest.TestCase):
13
13
  """Set up the tests."""
14
14
  model_name = "gpt-4o"
15
15
  self.maxDiff = None
16
- self.llm, _, _, _ = load_model(model_name)
16
+ self.llm = load_model(model_name)
17
17
 
18
18
  def _split(self):
19
19
  """Split the test file."""
@@ -7,10 +7,10 @@ from typing import Optional
7
7
 
8
8
  import tree_sitter
9
9
  from git import Repo
10
- from langchain.schema.language_model import BaseLanguageModel
11
10
 
12
11
  from janus.language.block import CodeBlock, NodeType
13
12
  from janus.language.splitter import Splitter
13
+ from janus.llm.models_info import JanusModel
14
14
  from janus.utils.enums import LANGUAGES
15
15
  from janus.utils.logger import create_logger
16
16
 
@@ -25,7 +25,7 @@ class TreeSitterSplitter(Splitter):
25
25
  def __init__(
26
26
  self,
27
27
  language: str,
28
- model: None | BaseLanguageModel = None,
28
+ model: JanusModel | None = None,
29
29
  max_tokens: int = 4096,
30
30
  protected_node_types: tuple[str, ...] = (),
31
31
  prune_node_types: tuple[str, ...] = (),
@@ -12,6 +12,22 @@ from janus.utils.logger import create_logger
12
12
 
13
13
  log = create_logger(__name__)
14
14
 
15
+ openai_model_reroutes = {
16
+ "gpt-4o": "gpt-4o-2024-05-13",
17
+ "gpt-4o-mini": "gpt-4o-mini",
18
+ "gpt-4": "gpt-4-0613",
19
+ "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
20
+ "gpt-4-turbo-preview": "gpt-4-0125-preview",
21
+ "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
22
+ "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
23
+ "gpt-3.5-turbo-16k-0613": "gpt-3.5-turbo-0125",
24
+ }
25
+
26
+ azure_model_reroutes = {
27
+ "gpt-4o": "gpt-4o-2024-08-06",
28
+ "gpt-4o-mini": "gpt-4o-mini",
29
+ "gpt-3.5-turbo-16k": "gpt35-turbo-16k",
30
+ }
15
31
 
16
32
  # Updated 2024-06-21
17
33
  COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
@@ -20,6 +36,10 @@ COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
20
36
  "gpt-4-0125-preview": {"input": 0.01, "output": 0.03},
21
37
  "gpt-4-0613": {"input": 0.03, "output": 0.06},
22
38
  "gpt-4o-2024-05-13": {"input": 0.005, "output": 0.015},
39
+ "gpt-4o-2024-08-06": {"input": 0.00275, "output": 0.011},
40
+ "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
41
+ "gpt35-turbo-16k": {"input": 0.003, "output": 0.004},
42
+ "gpt-35-turbo-16k": {"input": 0.003, "output": 0.004},
23
43
  "anthropic.claude-v2": {"input": 0.008, "output": 0.024},
24
44
  "anthropic.claude-instant-v1": {"input": 0.0008, "output": 0.0024},
25
45
  "anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.00025, "output": 0.00125},
@@ -45,6 +65,8 @@ def _get_token_cost(
45
65
  prompt_tokens: int, completion_tokens: int, model_id: str | None
46
66
  ) -> float:
47
67
  """Get the cost of tokens according to model ID"""
68
+ if model_id in openai_model_reroutes:
69
+ model_id = openai_model_reroutes[model_id]
48
70
  if model_id not in COST_PER_1K_TOKENS:
49
71
  raise ValueError(
50
72
  f"Unknown model: {model_id}. Please provide a valid model name."