janus-llm 3.5.3__py3-none-any.whl → 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli.py +66 -47
- janus/converter/converter.py +111 -142
- janus/converter/diagram.py +21 -109
- janus/converter/translate.py +1 -1
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +15 -10
- janus/language/binary/_tests/test_binary.py +1 -1
- janus/language/binary/binary.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +1 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/splitter.py +2 -2
- janus/language/treesitter/_tests/test_treesitter.py +1 -1
- janus/language/treesitter/treesitter.py +2 -2
- janus/llm/model_callbacks.py +13 -0
- janus/llm/models_info.py +111 -71
- janus/metrics/metric.py +15 -14
- janus/parsers/uml.py +60 -23
- janus/refiners/refiner.py +106 -64
- janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.0.0.dist-info}/METADATA +1 -1
- {janus_llm-3.5.3.dist-info → janus_llm-4.0.0.dist-info}/RECORD +25 -25
- janus/parsers/refiner_parser.py +0 -46
- {janus_llm-3.5.3.dist-info → janus_llm-4.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.0.0.dist-info}/entry_points.txt +0 -0
janus/converter/diagram.py
CHANGED
@@ -1,14 +1,6 @@
|
|
1
|
-
import
|
1
|
+
from langchain_core.runnables import Runnable, RunnableParallel
|
2
2
|
|
3
|
-
from langchain.output_parsers import RetryWithErrorOutputParser
|
4
|
-
from langchain_core.exceptions import OutputParserException
|
5
|
-
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
6
|
-
|
7
|
-
from janus.converter.converter import run_if_changed
|
8
3
|
from janus.converter.document import Documenter
|
9
|
-
from janus.language.block import TranslatedCodeBlock
|
10
|
-
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
11
|
-
from janus.parsers.refiner_parser import RefinerParser
|
12
4
|
from janus.parsers.uml import UMLSyntaxParser
|
13
5
|
from janus.utils.logger import create_logger
|
14
6
|
|
@@ -16,10 +8,7 @@ log = create_logger(__name__)
|
|
16
8
|
|
17
9
|
|
18
10
|
class DiagramGenerator(Documenter):
|
19
|
-
"""
|
20
|
-
|
21
|
-
A class that translates code from one programming language to a set of diagrams.
|
22
|
-
"""
|
11
|
+
"""A Converter that translates code into a set of PLANTUML diagrams."""
|
23
12
|
|
24
13
|
def __init__(
|
25
14
|
self,
|
@@ -30,110 +19,33 @@ class DiagramGenerator(Documenter):
|
|
30
19
|
"""Initialize the DiagramGenerator class
|
31
20
|
|
32
21
|
Arguments:
|
33
|
-
model: The LLM to use for translation. If an OpenAI model, the
|
34
|
-
`OPENAI_API_KEY` environment variable must be set and the
|
35
|
-
`OPENAI_ORG_ID` environment variable should be set if needed.
|
36
|
-
model_arguments: Additional arguments to pass to the LLM constructor.
|
37
|
-
source_language: The source programming language.
|
38
|
-
max_prompts: The maximum number of prompts to try before giving up.
|
39
|
-
db_path: path to chroma database
|
40
|
-
db_config: database configuraiton
|
41
22
|
diagram_type: type of PLANTUML diagram to generate
|
23
|
+
add_documentation: Whether to add a documentation step prior to
|
24
|
+
diagram generation.
|
42
25
|
"""
|
43
|
-
super().__init__(**kwargs)
|
44
26
|
self._diagram_type = diagram_type
|
45
27
|
self._add_documentation = add_documentation
|
46
|
-
self._documenter =
|
47
|
-
self._diagram_parser = UMLSyntaxParser(language="plantuml")
|
48
|
-
if add_documentation:
|
49
|
-
self._diagram_prompt_template_name = "diagram_with_documentation"
|
50
|
-
else:
|
51
|
-
self._diagram_prompt_template_name = "diagram"
|
52
|
-
self._load_diagram_prompt_engine()
|
28
|
+
self._documenter = Documenter(**kwargs)
|
53
29
|
|
54
|
-
|
55
|
-
|
56
|
-
|
30
|
+
super().__init__(**kwargs)
|
31
|
+
|
32
|
+
self.set_prompt("diagram_with_documentation" if add_documentation else "diagram")
|
33
|
+
self._parser = UMLSyntaxParser(language="plantuml")
|
57
34
|
|
58
|
-
|
59
|
-
n2 = round((self.max_prompts // n1) ** (1 / 2))
|
35
|
+
self._load_parameters()
|
60
36
|
|
61
|
-
|
62
|
-
|
37
|
+
def _load_prompt(self):
|
38
|
+
super()._load_prompt()
|
39
|
+
self._prompt = self._prompt.partial(DIAGRAM_TYPE=self._diagram_type)
|
63
40
|
|
41
|
+
def _input_runnable(self) -> Runnable:
|
64
42
|
if self._add_documentation:
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
**{
|
70
|
-
"SOURCE_CODE": input,
|
71
|
-
"DOCUMENTATION": documentation_text,
|
72
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
73
|
-
}
|
74
|
-
),
|
75
|
-
refiner=self._refiner,
|
76
|
-
max_retries=n1,
|
77
|
-
llm=self._llm,
|
78
|
-
)
|
79
|
-
else:
|
80
|
-
refine_output = RefinerParser(
|
81
|
-
parser=self._diagram_parser,
|
82
|
-
initial_prompt=self._diagram_prompt.format(
|
83
|
-
**{
|
84
|
-
"SOURCE_CODE": input,
|
85
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
86
|
-
}
|
87
|
-
),
|
88
|
-
refiner=self._refiner,
|
89
|
-
max_retries=n1,
|
90
|
-
llm=self._llm,
|
43
|
+
return RunnableParallel(
|
44
|
+
SOURCE_CODE=self._parser.parse_input,
|
45
|
+
DOCUMENTATION=self._documenter.chain,
|
46
|
+
context=self._retriever,
|
91
47
|
)
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
max_retries=n2,
|
96
|
-
)
|
97
|
-
completion_chain = self._prompt | self._llm
|
98
|
-
chain = RunnableParallel(
|
99
|
-
completion=completion_chain, prompt_value=self._diagram_prompt
|
100
|
-
) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
|
101
|
-
for _ in range(n3):
|
102
|
-
try:
|
103
|
-
if self._add_documentation:
|
104
|
-
return chain.invoke(
|
105
|
-
{
|
106
|
-
"SOURCE_CODE": input,
|
107
|
-
"DOCUMENTATION": documentation_text,
|
108
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
109
|
-
}
|
110
|
-
)
|
111
|
-
else:
|
112
|
-
return chain.invoke(
|
113
|
-
{
|
114
|
-
"SOURCE_CODE": input,
|
115
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
116
|
-
}
|
117
|
-
)
|
118
|
-
except OutputParserException:
|
119
|
-
pass
|
120
|
-
|
121
|
-
raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
|
122
|
-
|
123
|
-
@run_if_changed(
|
124
|
-
"_diagram_prompt_template_name",
|
125
|
-
"_source_language",
|
126
|
-
)
|
127
|
-
def _load_diagram_prompt_engine(self) -> None:
|
128
|
-
"""Load the prompt engine according to this instance's attributes.
|
129
|
-
|
130
|
-
If the relevant fields have not been changed since the last time this method was
|
131
|
-
called, nothing happens.
|
132
|
-
"""
|
133
|
-
self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
|
134
|
-
source_language=self._source_language,
|
135
|
-
target_language="text",
|
136
|
-
target_version=None,
|
137
|
-
prompt_template=self._diagram_prompt_template_name,
|
48
|
+
return RunnableParallel(
|
49
|
+
SOURCE_CODE=self._parser.parse_input,
|
50
|
+
context=self._retriever,
|
138
51
|
)
|
139
|
-
self._diagram_prompt = self._diagram_prompt_engine.prompt
|
janus/converter/translate.py
CHANGED
@@ -90,7 +90,7 @@ class Translator(Converter):
|
|
90
90
|
f"({self._source_language} != {self._target_language})"
|
91
91
|
)
|
92
92
|
|
93
|
-
prompt_engine = MODEL_PROMPT_ENGINES[self.
|
93
|
+
prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
|
94
94
|
source_language=self._source_language,
|
95
95
|
target_language=self._target_language,
|
96
96
|
target_version=self._target_version,
|
@@ -12,7 +12,7 @@ class TestAlcSplitter(unittest.TestCase):
|
|
12
12
|
def setUp(self):
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
|
-
llm
|
15
|
+
llm = load_model(model_name)
|
16
16
|
self.splitter = AlcSplitter(model=llm)
|
17
17
|
self.combiner = Combiner(language="ibmhlasm")
|
18
18
|
self.test_file = Path("janus/language/alc/_tests/alc.asm")
|
janus/language/alc/alc.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1
1
|
import re
|
2
2
|
from typing import Optional
|
3
3
|
|
4
|
-
from langchain.schema.language_model import BaseLanguageModel
|
5
|
-
|
6
4
|
from janus.language.block import CodeBlock
|
7
5
|
from janus.language.combine import Combiner
|
8
6
|
from janus.language.node import NodeType
|
9
7
|
from janus.language.treesitter import TreeSitterSplitter
|
8
|
+
from janus.llm.models_info import JanusModel
|
10
9
|
from janus.utils.logger import create_logger
|
11
10
|
|
12
11
|
log = create_logger(__name__)
|
@@ -27,7 +26,7 @@ class AlcSplitter(TreeSitterSplitter):
|
|
27
26
|
|
28
27
|
def __init__(
|
29
28
|
self,
|
30
|
-
model:
|
29
|
+
model: JanusModel | None = None,
|
31
30
|
max_tokens: int = 4096,
|
32
31
|
protected_node_types: tuple[str, ...] = (),
|
33
32
|
prune_node_types: tuple[str, ...] = (),
|
@@ -101,7 +100,7 @@ class AlcListingSplitter(AlcSplitter):
|
|
101
100
|
|
102
101
|
def __init__(
|
103
102
|
self,
|
104
|
-
model:
|
103
|
+
model: JanusModel | None = None,
|
105
104
|
max_tokens: int = 4096,
|
106
105
|
protected_node_types: tuple[str, ...] = (),
|
107
106
|
prune_node_types: tuple[str, ...] = (),
|
@@ -129,12 +128,18 @@ class AlcListingSplitter(AlcSplitter):
|
|
129
128
|
prune_unprotected=prune_unprotected,
|
130
129
|
)
|
131
130
|
|
132
|
-
def
|
131
|
+
def split_string(self, code: str, name: str) -> CodeBlock:
|
132
|
+
# Override split_string to use processed code and track active usings
|
133
133
|
active_usings = self.get_active_usings(code)
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
134
|
+
processed_code = self.preproccess_assembly(code)
|
135
|
+
root = super().split_string(processed_code, name)
|
136
|
+
if active_usings is not None:
|
137
|
+
stack = [root]
|
138
|
+
while stack:
|
139
|
+
block = stack.pop()
|
140
|
+
block.context_tags["active_usings"] = active_usings
|
141
|
+
stack.extend(block.children)
|
142
|
+
return root
|
138
143
|
|
139
144
|
def preproccess_assembly(self, code: str) -> str:
|
140
145
|
"""Remove non-essential lines from an assembly snippet"""
|
@@ -142,7 +147,7 @@ class AlcListingSplitter(AlcSplitter):
|
|
142
147
|
lines = code.splitlines()
|
143
148
|
lines = self.strip_header_and_left(lines)
|
144
149
|
lines = self.strip_addresses(lines)
|
145
|
-
return "".join(str(line) for line in lines)
|
150
|
+
return "\n".join(str(line) for line in lines)
|
146
151
|
|
147
152
|
def get_active_usings(self, code: str) -> Optional[str]:
|
148
153
|
"""Look for 'active usings' in the ALC listing header"""
|
@@ -15,7 +15,7 @@ class TestBinarySplitter(unittest.TestCase):
|
|
15
15
|
def setUp(self):
|
16
16
|
model_name = "gpt-4o"
|
17
17
|
self.binary_file = Path("janus/language/binary/_tests/hello")
|
18
|
-
self.llm
|
18
|
+
self.llm = load_model(model_name)
|
19
19
|
self.splitter = BinarySplitter(model=self.llm)
|
20
20
|
os.environ["GHIDRA_INSTALL_PATH"] = "~/programs/ghidra_10.4_PUBLIC"
|
21
21
|
|
janus/language/binary/binary.py
CHANGED
@@ -5,11 +5,11 @@ import tempfile
|
|
5
5
|
from pathlib import Path
|
6
6
|
|
7
7
|
import tree_sitter
|
8
|
-
from langchain.schema.language_model import BaseLanguageModel
|
9
8
|
|
10
9
|
from janus.language.block import CodeBlock
|
11
10
|
from janus.language.combine import Combiner
|
12
11
|
from janus.language.treesitter import TreeSitterSplitter
|
12
|
+
from janus.llm.models_info import JanusModel
|
13
13
|
from janus.utils.enums import LANGUAGES
|
14
14
|
from janus.utils.logger import create_logger
|
15
15
|
|
@@ -31,7 +31,7 @@ class BinarySplitter(TreeSitterSplitter):
|
|
31
31
|
|
32
32
|
def __init__(
|
33
33
|
self,
|
34
|
-
model:
|
34
|
+
model: JanusModel | None = None,
|
35
35
|
max_tokens: int = 4096,
|
36
36
|
protected_node_types: tuple[str] = (),
|
37
37
|
prune_node_types: tuple[str] = (),
|
@@ -12,7 +12,7 @@ class TestMumpsSplitter(unittest.TestCase):
|
|
12
12
|
def setUp(self):
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
|
-
llm
|
15
|
+
llm = load_model(model_name)
|
16
16
|
self.splitter = MumpsSplitter(model=llm)
|
17
17
|
self.combiner = Combiner(language="mumps")
|
18
18
|
self.test_file = Path("janus/language/mumps/_tests/mumps.m")
|
janus/language/mumps/mumps.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1
1
|
import re
|
2
2
|
|
3
|
-
from langchain.schema.language_model import BaseLanguageModel
|
4
|
-
|
5
3
|
from janus.language.block import CodeBlock
|
6
4
|
from janus.language.combine import Combiner
|
7
5
|
from janus.language.node import NodeType
|
8
6
|
from janus.language.splitter import Splitter
|
7
|
+
from janus.llm.models_info import JanusModel
|
9
8
|
from janus.utils.logger import create_logger
|
10
9
|
|
11
10
|
log = create_logger(__name__)
|
@@ -44,7 +43,7 @@ class MumpsSplitter(Splitter):
|
|
44
43
|
|
45
44
|
def __init__(
|
46
45
|
self,
|
47
|
-
model:
|
46
|
+
model: JanusModel | None = None,
|
48
47
|
max_tokens: int = 4096,
|
49
48
|
protected_node_types: tuple[str] = ("routine_definition",),
|
50
49
|
prune_node_types: tuple[str] = (),
|
janus/language/splitter.py
CHANGED
@@ -4,11 +4,11 @@ from pathlib import Path
|
|
4
4
|
from typing import List
|
5
5
|
|
6
6
|
import tiktoken
|
7
|
-
from langchain.schema.language_model import BaseLanguageModel
|
8
7
|
|
9
8
|
from janus.language.block import CodeBlock
|
10
9
|
from janus.language.file import FileManager
|
11
10
|
from janus.language.node import NodeType
|
11
|
+
from janus.llm.models_info import JanusModel
|
12
12
|
from janus.utils.logger import create_logger
|
13
13
|
|
14
14
|
log = create_logger(__name__)
|
@@ -44,7 +44,7 @@ class Splitter(FileManager):
|
|
44
44
|
def __init__(
|
45
45
|
self,
|
46
46
|
language: str,
|
47
|
-
model:
|
47
|
+
model: JanusModel | None = None,
|
48
48
|
max_tokens: int = 4096,
|
49
49
|
skip_merge: bool = False,
|
50
50
|
protected_node_types: tuple[str, ...] = (),
|
@@ -13,7 +13,7 @@ class TestTreeSitterSplitter(unittest.TestCase):
|
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
15
|
self.maxDiff = None
|
16
|
-
self.llm
|
16
|
+
self.llm = load_model(model_name)
|
17
17
|
|
18
18
|
def _split(self):
|
19
19
|
"""Split the test file."""
|
@@ -7,10 +7,10 @@ from typing import Optional
|
|
7
7
|
|
8
8
|
import tree_sitter
|
9
9
|
from git import Repo
|
10
|
-
from langchain.schema.language_model import BaseLanguageModel
|
11
10
|
|
12
11
|
from janus.language.block import CodeBlock, NodeType
|
13
12
|
from janus.language.splitter import Splitter
|
13
|
+
from janus.llm.models_info import JanusModel
|
14
14
|
from janus.utils.enums import LANGUAGES
|
15
15
|
from janus.utils.logger import create_logger
|
16
16
|
|
@@ -25,7 +25,7 @@ class TreeSitterSplitter(Splitter):
|
|
25
25
|
def __init__(
|
26
26
|
self,
|
27
27
|
language: str,
|
28
|
-
model:
|
28
|
+
model: JanusModel | None = None,
|
29
29
|
max_tokens: int = 4096,
|
30
30
|
protected_node_types: tuple[str, ...] = (),
|
31
31
|
prune_node_types: tuple[str, ...] = (),
|
janus/llm/model_callbacks.py
CHANGED
@@ -12,6 +12,17 @@ from janus.utils.logger import create_logger
|
|
12
12
|
|
13
13
|
log = create_logger(__name__)
|
14
14
|
|
15
|
+
openai_model_reroutes = {
|
16
|
+
"gpt-4o": "gpt-4o-2024-05-13",
|
17
|
+
"gpt-4o-mini": "gpt-4o-mini",
|
18
|
+
"gpt-4": "gpt-4-0613",
|
19
|
+
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
20
|
+
"gpt-4-turbo-preview": "gpt-4-0125-preview",
|
21
|
+
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
|
22
|
+
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
|
23
|
+
"gpt-3.5-turbo-16k-0613": "gpt-3.5-turbo-0125",
|
24
|
+
}
|
25
|
+
|
15
26
|
|
16
27
|
# Updated 2024-06-21
|
17
28
|
COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
|
@@ -45,6 +56,8 @@ def _get_token_cost(
|
|
45
56
|
prompt_tokens: int, completion_tokens: int, model_id: str | None
|
46
57
|
) -> float:
|
47
58
|
"""Get the cost of tokens according to model ID"""
|
59
|
+
if model_id in openai_model_reroutes:
|
60
|
+
model_id = openai_model_reroutes[model_id]
|
48
61
|
if model_id not in COST_PER_1K_TOKENS:
|
49
62
|
raise ValueError(
|
50
63
|
f"Unknown model: {model_id}. Please provide a valid model name."
|
janus/llm/models_info.py
CHANGED
@@ -2,14 +2,14 @@ import json
|
|
2
2
|
import os
|
3
3
|
import time
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import
|
5
|
+
from typing import Protocol, TypeVar
|
6
6
|
|
7
7
|
from dotenv import load_dotenv
|
8
8
|
from langchain_community.llms import HuggingFaceTextGenInference
|
9
|
-
from langchain_core.
|
9
|
+
from langchain_core.runnables import Runnable
|
10
10
|
from langchain_openai import ChatOpenAI
|
11
11
|
|
12
|
-
from janus.llm.model_callbacks import COST_PER_1K_TOKENS
|
12
|
+
from janus.llm.model_callbacks import COST_PER_1K_TOKENS, openai_model_reroutes
|
13
13
|
from janus.prompts.prompt import (
|
14
14
|
ChatGptPromptEngine,
|
15
15
|
ClaudePromptEngine,
|
@@ -44,17 +44,34 @@ except ImportError:
|
|
44
44
|
)
|
45
45
|
|
46
46
|
|
47
|
+
ModelType = TypeVar(
|
48
|
+
"ModelType",
|
49
|
+
ChatOpenAI,
|
50
|
+
HuggingFaceTextGenInference,
|
51
|
+
Bedrock,
|
52
|
+
BedrockChat,
|
53
|
+
HuggingFacePipeline,
|
54
|
+
)
|
55
|
+
|
56
|
+
|
57
|
+
class JanusModelProtocol(Protocol):
|
58
|
+
model_id: str
|
59
|
+
model_type_name: str
|
60
|
+
token_limit: int
|
61
|
+
input_token_cost: float
|
62
|
+
output_token_cost: float
|
63
|
+
prompt_engine: type[PromptEngine]
|
64
|
+
|
65
|
+
def get_num_tokens(self, text: str) -> int:
|
66
|
+
...
|
67
|
+
|
68
|
+
|
69
|
+
class JanusModel(Runnable, JanusModelProtocol):
|
70
|
+
...
|
71
|
+
|
72
|
+
|
47
73
|
load_dotenv()
|
48
74
|
|
49
|
-
openai_model_reroutes = {
|
50
|
-
"gpt-4o": "gpt-4o-2024-05-13",
|
51
|
-
"gpt-4o-mini": "gpt-4o-mini",
|
52
|
-
"gpt-4": "gpt-4-0613",
|
53
|
-
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
54
|
-
"gpt-4-turbo-preview": "gpt-4-0125-preview",
|
55
|
-
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
|
56
|
-
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
|
57
|
-
}
|
58
75
|
|
59
76
|
openai_models = [
|
60
77
|
"gpt-4o",
|
@@ -105,24 +122,15 @@ bedrock_models = [
|
|
105
122
|
]
|
106
123
|
all_models = [*openai_models, *bedrock_models]
|
107
124
|
|
108
|
-
MODEL_TYPE_CONSTRUCTORS: dict[str,
|
125
|
+
MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
|
109
126
|
"OpenAI": ChatOpenAI,
|
110
127
|
"HuggingFace": HuggingFaceTextGenInference,
|
128
|
+
"Bedrock": Bedrock,
|
129
|
+
"BedrockChat": BedrockChat,
|
130
|
+
"HuggingFaceLocal": HuggingFacePipeline,
|
111
131
|
}
|
112
132
|
|
113
|
-
|
114
|
-
MODEL_TYPE_CONSTRUCTORS.update(
|
115
|
-
{
|
116
|
-
"HuggingFaceLocal": HuggingFacePipeline.from_model_id,
|
117
|
-
"Bedrock": Bedrock,
|
118
|
-
"BedrockChat": BedrockChat,
|
119
|
-
}
|
120
|
-
)
|
121
|
-
except NameError:
|
122
|
-
pass
|
123
|
-
|
124
|
-
|
125
|
-
MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
133
|
+
MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
|
126
134
|
**{m: ChatGptPromptEngine for m in openai_models},
|
127
135
|
**{m: ClaudePromptEngine for m in claude_models},
|
128
136
|
**{m: Llama2PromptEngine for m in llama2_models},
|
@@ -132,11 +140,6 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
|
132
140
|
**{m: MistralPromptEngine for m in mistral_models},
|
133
141
|
}
|
134
142
|
|
135
|
-
_open_ai_defaults: dict[str, str] = {
|
136
|
-
"openai_api_key": os.getenv("OPENAI_API_KEY"),
|
137
|
-
"openai_organization": os.getenv("OPENAI_ORG_ID"),
|
138
|
-
}
|
139
|
-
|
140
143
|
MODEL_ID_TO_LONG_ID = {
|
141
144
|
**{m: mr for m, mr in openai_model_reroutes.items()},
|
142
145
|
"bedrock-claude-v2": "anthropic.claude-v2",
|
@@ -168,7 +171,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
|
|
168
171
|
|
169
172
|
MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
|
170
173
|
|
171
|
-
MODEL_TYPES: dict[str,
|
174
|
+
MODEL_TYPES: dict[str, str] = {
|
172
175
|
**{m: "OpenAI" for m in openai_models},
|
173
176
|
**{m: "BedrockChat" for m in bedrock_models},
|
174
177
|
}
|
@@ -211,53 +214,90 @@ def get_available_model_names() -> list[str]:
|
|
211
214
|
return avaialable_models
|
212
215
|
|
213
216
|
|
214
|
-
def load_model(
|
215
|
-
user_model_name: str,
|
216
|
-
) -> tuple[BaseLanguageModel, str, int, dict[str, float]]:
|
217
|
+
def load_model(model_id) -> JanusModel:
|
217
218
|
if not MODEL_CONFIG_DIR.exists():
|
218
219
|
MODEL_CONFIG_DIR.mkdir(parents=True)
|
219
|
-
model_config_file = MODEL_CONFIG_DIR / f"{
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
f"default models for {user_model_name}."
|
224
|
-
)
|
225
|
-
model_id = user_model_name
|
226
|
-
if user_model_name not in DEFAULT_MODELS:
|
227
|
-
message = (
|
228
|
-
f"Model {user_model_name} not found in default models. Make sure to run "
|
229
|
-
"`janus llm add` first."
|
230
|
-
)
|
231
|
-
log.error(message)
|
232
|
-
raise ValueError(message)
|
233
|
-
model_config = {
|
234
|
-
"model_type": MODEL_TYPES[model_id],
|
235
|
-
"model_id": model_id,
|
236
|
-
"model_args": MODEL_DEFAULT_ARGUMENTS[model_id],
|
237
|
-
"token_limit": TOKEN_LIMITS.get(MODEL_ID_TO_LONG_ID[model_id], 4096),
|
238
|
-
"model_cost": COST_PER_1K_TOKENS.get(
|
239
|
-
MODEL_ID_TO_LONG_ID[model_id], {"input": 0, "output": 0}
|
240
|
-
),
|
241
|
-
}
|
242
|
-
with open(model_config_file, "w") as f:
|
243
|
-
json.dump(model_config, f)
|
244
|
-
else:
|
220
|
+
model_config_file = MODEL_CONFIG_DIR / f"{model_id}.json"
|
221
|
+
|
222
|
+
if model_config_file.exists():
|
223
|
+
log.info(f"Loading {model_id} from {model_config_file}.")
|
245
224
|
with open(model_config_file, "r") as f:
|
246
225
|
model_config = json.load(f)
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
226
|
+
model_type_name = model_config["model_type"]
|
227
|
+
model_id = model_config["model_id"]
|
228
|
+
model_args = model_config["model_args"]
|
229
|
+
token_limit = model_config["token_limit"]
|
230
|
+
input_token_cost = model_config["model_cost"]["input"]
|
231
|
+
output_token_cost = model_config["model_cost"]["output"]
|
232
|
+
|
233
|
+
elif model_id in DEFAULT_MODELS:
|
234
|
+
model_id = model_id
|
235
|
+
model_long_id = MODEL_ID_TO_LONG_ID[model_id]
|
236
|
+
model_type_name = MODEL_TYPES[model_id]
|
237
|
+
model_args = MODEL_DEFAULT_ARGUMENTS[model_id]
|
238
|
+
|
239
|
+
token_limit = 0
|
240
|
+
input_token_cost = 0.0
|
241
|
+
output_token_cost = 0.0
|
242
|
+
if model_long_id in TOKEN_LIMITS:
|
243
|
+
token_limit = TOKEN_LIMITS[model_long_id]
|
244
|
+
if model_long_id in COST_PER_1K_TOKENS:
|
245
|
+
token_limits = COST_PER_1K_TOKENS[model_long_id]
|
246
|
+
input_token_cost = token_limits["input"]
|
247
|
+
output_token_cost = token_limits["output"]
|
248
|
+
|
249
|
+
else:
|
250
|
+
model_list = "\n\t".join(DEFAULT_MODELS)
|
251
|
+
message = (
|
252
|
+
f"Model {model_id} not found in user-defined model directory "
|
253
|
+
f"({MODEL_CONFIG_DIR}), and is not a default model. Valid default "
|
254
|
+
f"models:\n\t{model_list}\n"
|
255
|
+
f"To use a custom model, first run `janus llm add`."
|
256
|
+
)
|
257
|
+
log.error(message)
|
258
|
+
raise ValueError(message)
|
259
|
+
|
260
|
+
if model_type_name == "HuggingFaceLocal":
|
261
|
+
model = HuggingFacePipeline.from_model_id(
|
262
|
+
model_id=model_id,
|
263
|
+
task="text-generation",
|
264
|
+
model_kwargs=model_args,
|
265
|
+
)
|
266
|
+
model_args.update(pipeline=model.pipeline)
|
267
|
+
|
268
|
+
elif model_type_name == "OpenAI":
|
269
|
+
model_args.update(
|
270
|
+
openai_api_key=str(os.getenv("OPENAI_API_KEY")),
|
271
|
+
openai_organization=str(os.getenv("OPENAI_ORG_ID")),
|
272
|
+
)
|
251
273
|
log.warning("Do NOT use this model in sensitive environments!")
|
252
274
|
log.warning("If you would like to cancel, please press Ctrl+C.")
|
253
275
|
log.warning("Waiting 10 seconds...")
|
254
276
|
# Give enough time for the user to read the warnings and cancel
|
255
277
|
time.sleep(10)
|
256
278
|
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
279
|
+
model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
|
280
|
+
prompt_engine = MODEL_PROMPT_ENGINES[model_id]
|
281
|
+
|
282
|
+
class JanusModel(model_type):
|
283
|
+
model_id: str
|
284
|
+
short_model_id: str
|
285
|
+
model_type_name: str
|
286
|
+
token_limit: int
|
287
|
+
input_token_cost: float
|
288
|
+
output_token_cost: float
|
289
|
+
prompt_engine: type[PromptEngine]
|
290
|
+
|
291
|
+
model_args.update(
|
292
|
+
model_id=MODEL_ID_TO_LONG_ID[model_id],
|
293
|
+
short_model_id=model_id,
|
294
|
+
)
|
295
|
+
|
296
|
+
return JanusModel(
|
297
|
+
model_type_name=model_type_name,
|
298
|
+
token_limit=token_limit,
|
299
|
+
input_token_cost=input_token_cost,
|
300
|
+
output_token_cost=output_token_cost,
|
301
|
+
prompt_engine=prompt_engine,
|
302
|
+
**model_args,
|
263
303
|
)
|