janus-llm 3.5.2__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
@@ -1,14 +1,6 @@
1
- import math
1
+ from langchain_core.runnables import Runnable, RunnableParallel
2
2
 
3
- from langchain.output_parsers import RetryWithErrorOutputParser
4
- from langchain_core.exceptions import OutputParserException
5
- from langchain_core.runnables import RunnableLambda, RunnableParallel
6
-
7
- from janus.converter.converter import run_if_changed
8
3
  from janus.converter.document import Documenter
9
- from janus.language.block import TranslatedCodeBlock
10
- from janus.llm.models_info import MODEL_PROMPT_ENGINES
11
- from janus.parsers.refiner_parser import RefinerParser
12
4
  from janus.parsers.uml import UMLSyntaxParser
13
5
  from janus.utils.logger import create_logger
14
6
 
@@ -16,10 +8,7 @@ log = create_logger(__name__)
16
8
 
17
9
 
18
10
  class DiagramGenerator(Documenter):
19
- """DiagramGenerator
20
-
21
- A class that translates code from one programming language to a set of diagrams.
22
- """
11
+ """A Converter that translates code into a set of PLANTUML diagrams."""
23
12
 
24
13
  def __init__(
25
14
  self,
@@ -30,110 +19,33 @@ class DiagramGenerator(Documenter):
30
19
  """Initialize the DiagramGenerator class
31
20
 
32
21
  Arguments:
33
- model: The LLM to use for translation. If an OpenAI model, the
34
- `OPENAI_API_KEY` environment variable must be set and the
35
- `OPENAI_ORG_ID` environment variable should be set if needed.
36
- model_arguments: Additional arguments to pass to the LLM constructor.
37
- source_language: The source programming language.
38
- max_prompts: The maximum number of prompts to try before giving up.
39
- db_path: path to chroma database
40
- db_config: database configuraiton
41
22
  diagram_type: type of PLANTUML diagram to generate
23
+ add_documentation: Whether to add a documentation step prior to
24
+ diagram generation.
42
25
  """
43
- super().__init__(**kwargs)
44
26
  self._diagram_type = diagram_type
45
27
  self._add_documentation = add_documentation
46
- self._documenter = None
47
- self._diagram_parser = UMLSyntaxParser(language="plantuml")
48
- if add_documentation:
49
- self._diagram_prompt_template_name = "diagram_with_documentation"
50
- else:
51
- self._diagram_prompt_template_name = "diagram"
52
- self._load_diagram_prompt_engine()
28
+ self._documenter = Documenter(**kwargs)
53
29
 
54
- def _run_chain(self, block: TranslatedCodeBlock) -> str:
55
- input = self._parser.parse_input(block.original)
56
- n1 = round(self.max_prompts ** (1 / 3))
30
+ super().__init__(**kwargs)
31
+
32
+ self.set_prompt("diagram_with_documentation" if add_documentation else "diagram")
33
+ self._parser = UMLSyntaxParser(language="plantuml")
57
34
 
58
- # Retries with the input, output, and error
59
- n2 = round((self.max_prompts // n1) ** (1 / 2))
35
+ self._load_parameters()
60
36
 
61
- # Retries with just the input
62
- n3 = math.ceil(self.max_prompts / (n1 * n2))
37
+ def _load_prompt(self):
38
+ super()._load_prompt()
39
+ self._prompt = self._prompt.partial(DIAGRAM_TYPE=self._diagram_type)
63
40
 
41
+ def _input_runnable(self) -> Runnable:
64
42
  if self._add_documentation:
65
- documentation_text = super()._run_chain(block)
66
- refine_output = RefinerParser(
67
- parser=self._diagram_parser,
68
- initial_prompt=self._diagram_prompt.format(
69
- **{
70
- "SOURCE_CODE": input,
71
- "DOCUMENTATION": documentation_text,
72
- "DIAGRAM_TYPE": self._diagram_type,
73
- }
74
- ),
75
- refiner=self._refiner,
76
- max_retries=n1,
77
- llm=self._llm,
78
- )
79
- else:
80
- refine_output = RefinerParser(
81
- parser=self._diagram_parser,
82
- initial_prompt=self._diagram_prompt.format(
83
- **{
84
- "SOURCE_CODE": input,
85
- "DIAGRAM_TYPE": self._diagram_type,
86
- }
87
- ),
88
- refiner=self._refiner,
89
- max_retries=n1,
90
- llm=self._llm,
43
+ return RunnableParallel(
44
+ SOURCE_CODE=self._parser.parse_input,
45
+ DOCUMENTATION=self._documenter.chain,
46
+ context=self._retriever,
91
47
  )
92
- retry = RetryWithErrorOutputParser.from_llm(
93
- llm=self._llm,
94
- parser=refine_output,
95
- max_retries=n2,
96
- )
97
- completion_chain = self._prompt | self._llm
98
- chain = RunnableParallel(
99
- completion=completion_chain, prompt_value=self._diagram_prompt
100
- ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
101
- for _ in range(n3):
102
- try:
103
- if self._add_documentation:
104
- return chain.invoke(
105
- {
106
- "SOURCE_CODE": input,
107
- "DOCUMENTATION": documentation_text,
108
- "DIAGRAM_TYPE": self._diagram_type,
109
- }
110
- )
111
- else:
112
- return chain.invoke(
113
- {
114
- "SOURCE_CODE": input,
115
- "DIAGRAM_TYPE": self._diagram_type,
116
- }
117
- )
118
- except OutputParserException:
119
- pass
120
-
121
- raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
122
-
123
- @run_if_changed(
124
- "_diagram_prompt_template_name",
125
- "_source_language",
126
- )
127
- def _load_diagram_prompt_engine(self) -> None:
128
- """Load the prompt engine according to this instance's attributes.
129
-
130
- If the relevant fields have not been changed since the last time this method was
131
- called, nothing happens.
132
- """
133
- self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
134
- source_language=self._source_language,
135
- target_language="text",
136
- target_version=None,
137
- prompt_template=self._diagram_prompt_template_name,
48
+ return RunnableParallel(
49
+ SOURCE_CODE=self._parser.parse_input,
50
+ context=self._retriever,
138
51
  )
139
- self._diagram_prompt = self._diagram_prompt_engine.prompt
@@ -90,7 +90,7 @@ class Translator(Converter):
90
90
  f"({self._source_language} != {self._target_language})"
91
91
  )
92
92
 
93
- prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
93
+ prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
94
94
  source_language=self._source_language,
95
95
  target_language=self._target_language,
96
96
  target_version=self._target_version,
@@ -12,7 +12,7 @@ class TestAlcSplitter(unittest.TestCase):
12
12
  def setUp(self):
13
13
  """Set up the tests."""
14
14
  model_name = "gpt-4o"
15
- llm, _, _, _ = load_model(model_name)
15
+ llm = load_model(model_name)
16
16
  self.splitter = AlcSplitter(model=llm)
17
17
  self.combiner = Combiner(language="ibmhlasm")
18
18
  self.test_file = Path("janus/language/alc/_tests/alc.asm")
janus/language/alc/alc.py CHANGED
@@ -1,12 +1,11 @@
1
1
  import re
2
2
  from typing import Optional
3
3
 
4
- from langchain.schema.language_model import BaseLanguageModel
5
-
6
4
  from janus.language.block import CodeBlock
7
5
  from janus.language.combine import Combiner
8
6
  from janus.language.node import NodeType
9
7
  from janus.language.treesitter import TreeSitterSplitter
8
+ from janus.llm.models_info import JanusModel
10
9
  from janus.utils.logger import create_logger
11
10
 
12
11
  log = create_logger(__name__)
@@ -27,7 +26,7 @@ class AlcSplitter(TreeSitterSplitter):
27
26
 
28
27
  def __init__(
29
28
  self,
30
- model: None | BaseLanguageModel = None,
29
+ model: JanusModel | None = None,
31
30
  max_tokens: int = 4096,
32
31
  protected_node_types: tuple[str, ...] = (),
33
32
  prune_node_types: tuple[str, ...] = (),
@@ -63,7 +62,7 @@ class AlcSplitter(TreeSitterSplitter):
63
62
  # instruction and containing all the subsequent nodes up until the
64
63
  # next csect or dsect instruction
65
64
  sects: list[list[CodeBlock]] = [[]]
66
- for c in block.children:
65
+ for c in sorted(block.children):
67
66
  if c.node_type == "csect_instruction":
68
67
  c.context_tags["alc_section"] = "CSECT"
69
68
  sects.append([c])
@@ -101,7 +100,7 @@ class AlcListingSplitter(AlcSplitter):
101
100
 
102
101
  def __init__(
103
102
  self,
104
- model: None | BaseLanguageModel = None,
103
+ model: JanusModel | None = None,
105
104
  max_tokens: int = 4096,
106
105
  protected_node_types: tuple[str, ...] = (),
107
106
  prune_node_types: tuple[str, ...] = (),
@@ -129,12 +128,18 @@ class AlcListingSplitter(AlcSplitter):
129
128
  prune_unprotected=prune_unprotected,
130
129
  )
131
130
 
132
- def _get_ast(self, code: str) -> CodeBlock:
131
+ def split_string(self, code: str, name: str) -> CodeBlock:
132
+ # Override split_string to use processed code and track active usings
133
133
  active_usings = self.get_active_usings(code)
134
- code = self.preproccess_assembly(code)
135
- ast: CodeBlock = super()._get_ast(code)
136
- ast.context_tags["active_usings"] = active_usings
137
- return ast
134
+ processed_code = self.preproccess_assembly(code)
135
+ root = super().split_string(processed_code, name)
136
+ if active_usings is not None:
137
+ stack = [root]
138
+ while stack:
139
+ block = stack.pop()
140
+ block.context_tags["active_usings"] = active_usings
141
+ stack.extend(block.children)
142
+ return root
138
143
 
139
144
  def preproccess_assembly(self, code: str) -> str:
140
145
  """Remove non-essential lines from an assembly snippet"""
@@ -142,7 +147,7 @@ class AlcListingSplitter(AlcSplitter):
142
147
  lines = code.splitlines()
143
148
  lines = self.strip_header_and_left(lines)
144
149
  lines = self.strip_addresses(lines)
145
- return "".join(str(line) for line in lines)
150
+ return "\n".join(str(line) for line in lines)
146
151
 
147
152
  def get_active_usings(self, code: str) -> Optional[str]:
148
153
  """Look for 'active usings' in the ALC listing header"""
@@ -15,7 +15,7 @@ class TestBinarySplitter(unittest.TestCase):
15
15
  def setUp(self):
16
16
  model_name = "gpt-4o"
17
17
  self.binary_file = Path("janus/language/binary/_tests/hello")
18
- self.llm, _, _, _ = load_model(model_name)
18
+ self.llm = load_model(model_name)
19
19
  self.splitter = BinarySplitter(model=self.llm)
20
20
  os.environ["GHIDRA_INSTALL_PATH"] = "~/programs/ghidra_10.4_PUBLIC"
21
21
 
@@ -5,11 +5,11 @@ import tempfile
5
5
  from pathlib import Path
6
6
 
7
7
  import tree_sitter
8
- from langchain.schema.language_model import BaseLanguageModel
9
8
 
10
9
  from janus.language.block import CodeBlock
11
10
  from janus.language.combine import Combiner
12
11
  from janus.language.treesitter import TreeSitterSplitter
12
+ from janus.llm.models_info import JanusModel
13
13
  from janus.utils.enums import LANGUAGES
14
14
  from janus.utils.logger import create_logger
15
15
 
@@ -31,7 +31,7 @@ class BinarySplitter(TreeSitterSplitter):
31
31
 
32
32
  def __init__(
33
33
  self,
34
- model: None | BaseLanguageModel = None,
34
+ model: JanusModel | None = None,
35
35
  max_tokens: int = 4096,
36
36
  protected_node_types: tuple[str] = (),
37
37
  prune_node_types: tuple[str] = (),
@@ -12,7 +12,7 @@ class TestMumpsSplitter(unittest.TestCase):
12
12
  def setUp(self):
13
13
  """Set up the tests."""
14
14
  model_name = "gpt-4o"
15
- llm, _, _, _ = load_model(model_name)
15
+ llm = load_model(model_name)
16
16
  self.splitter = MumpsSplitter(model=llm)
17
17
  self.combiner = Combiner(language="mumps")
18
18
  self.test_file = Path("janus/language/mumps/_tests/mumps.m")
@@ -1,11 +1,10 @@
1
1
  import re
2
2
 
3
- from langchain.schema.language_model import BaseLanguageModel
4
-
5
3
  from janus.language.block import CodeBlock
6
4
  from janus.language.combine import Combiner
7
5
  from janus.language.node import NodeType
8
6
  from janus.language.splitter import Splitter
7
+ from janus.llm.models_info import JanusModel
9
8
  from janus.utils.logger import create_logger
10
9
 
11
10
  log = create_logger(__name__)
@@ -44,7 +43,7 @@ class MumpsSplitter(Splitter):
44
43
 
45
44
  def __init__(
46
45
  self,
47
- model: None | BaseLanguageModel = None,
46
+ model: JanusModel | None = None,
48
47
  max_tokens: int = 4096,
49
48
  protected_node_types: tuple[str] = ("routine_definition",),
50
49
  prune_node_types: tuple[str] = (),
@@ -19,6 +19,7 @@ def get_flexible_ast(language: str, **kwargs) -> Splitter:
19
19
  Returns:
20
20
  A flexible AST splitter for the given language.
21
21
  """
22
+ kwargs.update(protected_node_types=())
22
23
  if language == "ibmhlasm":
23
24
  return AlcSplitter(**kwargs)
24
25
  elif language == "mumps":
@@ -28,7 +29,7 @@ def get_flexible_ast(language: str, **kwargs) -> Splitter:
28
29
 
29
30
 
30
31
  @register_splitter("ast-strict")
31
- def get_strict_ast(language: str, **kwargs) -> Splitter:
32
+ def get_strict_ast(language: str, prune_unprotected=True, **kwargs) -> Splitter:
32
33
  """Get a strict AST splitter for the given language.
33
34
 
34
35
  The strict splitter will only return nodes that are of a functional type.
@@ -41,7 +42,7 @@ def get_strict_ast(language: str, **kwargs) -> Splitter:
41
42
  """
42
43
  kwargs.update(
43
44
  protected_node_types=LANGUAGES[language]["functional_node_types"],
44
- prune_unprotected=True,
45
+ prune_unprotected=prune_unprotected,
45
46
  )
46
47
  if language == "ibmhlasm":
47
48
  return AlcSplitter(**kwargs)
@@ -4,11 +4,11 @@ from pathlib import Path
4
4
  from typing import List
5
5
 
6
6
  import tiktoken
7
- from langchain.schema.language_model import BaseLanguageModel
8
7
 
9
8
  from janus.language.block import CodeBlock
10
9
  from janus.language.file import FileManager
11
10
  from janus.language.node import NodeType
11
+ from janus.llm.models_info import JanusModel
12
12
  from janus.utils.logger import create_logger
13
13
 
14
14
  log = create_logger(__name__)
@@ -44,7 +44,7 @@ class Splitter(FileManager):
44
44
  def __init__(
45
45
  self,
46
46
  language: str,
47
- model: None | BaseLanguageModel = None,
47
+ model: JanusModel | None = None,
48
48
  max_tokens: int = 4096,
49
49
  skip_merge: bool = False,
50
50
  protected_node_types: tuple[str, ...] = (),
@@ -387,7 +387,10 @@ class Splitter(FileManager):
387
387
  return
388
388
 
389
389
  if self._is_protected(node):
390
- raise TokenLimitError(r"Irreducible node too large for context!")
390
+ log.error(
391
+ "Protected node too large for context!"
392
+ f" ({node.tokens} > {self.max_tokens})"
393
+ )
391
394
 
392
395
  if node.children:
393
396
  for child in node.children:
@@ -423,7 +426,7 @@ class Splitter(FileManager):
423
426
  name = f"{node.name}-L#{node_line}"
424
427
  tokens = self._count_tokens(line)
425
428
  if tokens > self.max_tokens:
426
- raise TokenLimitError(
429
+ log.error(
427
430
  "Irreducible node too large for context!"
428
431
  f" ({tokens} > {self.max_tokens})"
429
432
  )
@@ -13,7 +13,7 @@ class TestTreeSitterSplitter(unittest.TestCase):
13
13
  """Set up the tests."""
14
14
  model_name = "gpt-4o"
15
15
  self.maxDiff = None
16
- self.llm, _, _, _ = load_model(model_name)
16
+ self.llm = load_model(model_name)
17
17
 
18
18
  def _split(self):
19
19
  """Split the test file."""
@@ -7,10 +7,10 @@ from typing import Optional
7
7
 
8
8
  import tree_sitter
9
9
  from git import Repo
10
- from langchain.schema.language_model import BaseLanguageModel
11
10
 
12
11
  from janus.language.block import CodeBlock, NodeType
13
12
  from janus.language.splitter import Splitter
13
+ from janus.llm.models_info import JanusModel
14
14
  from janus.utils.enums import LANGUAGES
15
15
  from janus.utils.logger import create_logger
16
16
 
@@ -25,7 +25,7 @@ class TreeSitterSplitter(Splitter):
25
25
  def __init__(
26
26
  self,
27
27
  language: str,
28
- model: None | BaseLanguageModel = None,
28
+ model: JanusModel | None = None,
29
29
  max_tokens: int = 4096,
30
30
  protected_node_types: tuple[str, ...] = (),
31
31
  prune_node_types: tuple[str, ...] = (),
@@ -12,6 +12,17 @@ from janus.utils.logger import create_logger
12
12
 
13
13
  log = create_logger(__name__)
14
14
 
15
+ openai_model_reroutes = {
16
+ "gpt-4o": "gpt-4o-2024-05-13",
17
+ "gpt-4o-mini": "gpt-4o-mini",
18
+ "gpt-4": "gpt-4-0613",
19
+ "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
20
+ "gpt-4-turbo-preview": "gpt-4-0125-preview",
21
+ "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
22
+ "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
23
+ "gpt-3.5-turbo-16k-0613": "gpt-3.5-turbo-0125",
24
+ }
25
+
15
26
 
16
27
  # Updated 2024-06-21
17
28
  COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
@@ -45,6 +56,8 @@ def _get_token_cost(
45
56
  prompt_tokens: int, completion_tokens: int, model_id: str | None
46
57
  ) -> float:
47
58
  """Get the cost of tokens according to model ID"""
59
+ if model_id in openai_model_reroutes:
60
+ model_id = openai_model_reroutes[model_id]
48
61
  if model_id not in COST_PER_1K_TOKENS:
49
62
  raise ValueError(
50
63
  f"Unknown model: {model_id}. Please provide a valid model name."