janus-llm 3.5.3__py3-none-any.whl → 4.0.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- janus/__init__.py +1 -1
- janus/cli.py +66 -47
- janus/converter/converter.py +111 -142
- janus/converter/diagram.py +21 -109
- janus/converter/translate.py +1 -1
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +15 -10
- janus/language/binary/_tests/test_binary.py +1 -1
- janus/language/binary/binary.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +1 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/splitter.py +2 -2
- janus/language/treesitter/_tests/test_treesitter.py +1 -1
- janus/language/treesitter/treesitter.py +2 -2
- janus/llm/model_callbacks.py +13 -0
- janus/llm/models_info.py +111 -71
- janus/metrics/metric.py +15 -14
- janus/parsers/uml.py +60 -23
- janus/refiners/refiner.py +106 -64
- janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.0.0.dist-info}/METADATA +1 -1
- {janus_llm-3.5.3.dist-info → janus_llm-4.0.0.dist-info}/RECORD +25 -25
- janus/parsers/refiner_parser.py +0 -46
- {janus_llm-3.5.3.dist-info → janus_llm-4.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.5.3.dist-info → janus_llm-4.0.0.dist-info}/entry_points.txt +0 -0
janus/converter/diagram.py
CHANGED
@@ -1,14 +1,6 @@
|
|
1
|
-
import
|
1
|
+
from langchain_core.runnables import Runnable, RunnableParallel
|
2
2
|
|
3
|
-
from langchain.output_parsers import RetryWithErrorOutputParser
|
4
|
-
from langchain_core.exceptions import OutputParserException
|
5
|
-
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
6
|
-
|
7
|
-
from janus.converter.converter import run_if_changed
|
8
3
|
from janus.converter.document import Documenter
|
9
|
-
from janus.language.block import TranslatedCodeBlock
|
10
|
-
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
11
|
-
from janus.parsers.refiner_parser import RefinerParser
|
12
4
|
from janus.parsers.uml import UMLSyntaxParser
|
13
5
|
from janus.utils.logger import create_logger
|
14
6
|
|
@@ -16,10 +8,7 @@ log = create_logger(__name__)
|
|
16
8
|
|
17
9
|
|
18
10
|
class DiagramGenerator(Documenter):
|
19
|
-
"""
|
20
|
-
|
21
|
-
A class that translates code from one programming language to a set of diagrams.
|
22
|
-
"""
|
11
|
+
"""A Converter that translates code into a set of PLANTUML diagrams."""
|
23
12
|
|
24
13
|
def __init__(
|
25
14
|
self,
|
@@ -30,110 +19,33 @@ class DiagramGenerator(Documenter):
|
|
30
19
|
"""Initialize the DiagramGenerator class
|
31
20
|
|
32
21
|
Arguments:
|
33
|
-
model: The LLM to use for translation. If an OpenAI model, the
|
34
|
-
`OPENAI_API_KEY` environment variable must be set and the
|
35
|
-
`OPENAI_ORG_ID` environment variable should be set if needed.
|
36
|
-
model_arguments: Additional arguments to pass to the LLM constructor.
|
37
|
-
source_language: The source programming language.
|
38
|
-
max_prompts: The maximum number of prompts to try before giving up.
|
39
|
-
db_path: path to chroma database
|
40
|
-
db_config: database configuraiton
|
41
22
|
diagram_type: type of PLANTUML diagram to generate
|
23
|
+
add_documentation: Whether to add a documentation step prior to
|
24
|
+
diagram generation.
|
42
25
|
"""
|
43
|
-
super().__init__(**kwargs)
|
44
26
|
self._diagram_type = diagram_type
|
45
27
|
self._add_documentation = add_documentation
|
46
|
-
self._documenter =
|
47
|
-
self._diagram_parser = UMLSyntaxParser(language="plantuml")
|
48
|
-
if add_documentation:
|
49
|
-
self._diagram_prompt_template_name = "diagram_with_documentation"
|
50
|
-
else:
|
51
|
-
self._diagram_prompt_template_name = "diagram"
|
52
|
-
self._load_diagram_prompt_engine()
|
28
|
+
self._documenter = Documenter(**kwargs)
|
53
29
|
|
54
|
-
|
55
|
-
|
56
|
-
|
30
|
+
super().__init__(**kwargs)
|
31
|
+
|
32
|
+
self.set_prompt("diagram_with_documentation" if add_documentation else "diagram")
|
33
|
+
self._parser = UMLSyntaxParser(language="plantuml")
|
57
34
|
|
58
|
-
|
59
|
-
n2 = round((self.max_prompts // n1) ** (1 / 2))
|
35
|
+
self._load_parameters()
|
60
36
|
|
61
|
-
|
62
|
-
|
37
|
+
def _load_prompt(self):
|
38
|
+
super()._load_prompt()
|
39
|
+
self._prompt = self._prompt.partial(DIAGRAM_TYPE=self._diagram_type)
|
63
40
|
|
41
|
+
def _input_runnable(self) -> Runnable:
|
64
42
|
if self._add_documentation:
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
**{
|
70
|
-
"SOURCE_CODE": input,
|
71
|
-
"DOCUMENTATION": documentation_text,
|
72
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
73
|
-
}
|
74
|
-
),
|
75
|
-
refiner=self._refiner,
|
76
|
-
max_retries=n1,
|
77
|
-
llm=self._llm,
|
78
|
-
)
|
79
|
-
else:
|
80
|
-
refine_output = RefinerParser(
|
81
|
-
parser=self._diagram_parser,
|
82
|
-
initial_prompt=self._diagram_prompt.format(
|
83
|
-
**{
|
84
|
-
"SOURCE_CODE": input,
|
85
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
86
|
-
}
|
87
|
-
),
|
88
|
-
refiner=self._refiner,
|
89
|
-
max_retries=n1,
|
90
|
-
llm=self._llm,
|
43
|
+
return RunnableParallel(
|
44
|
+
SOURCE_CODE=self._parser.parse_input,
|
45
|
+
DOCUMENTATION=self._documenter.chain,
|
46
|
+
context=self._retriever,
|
91
47
|
)
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
max_retries=n2,
|
96
|
-
)
|
97
|
-
completion_chain = self._prompt | self._llm
|
98
|
-
chain = RunnableParallel(
|
99
|
-
completion=completion_chain, prompt_value=self._diagram_prompt
|
100
|
-
) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
|
101
|
-
for _ in range(n3):
|
102
|
-
try:
|
103
|
-
if self._add_documentation:
|
104
|
-
return chain.invoke(
|
105
|
-
{
|
106
|
-
"SOURCE_CODE": input,
|
107
|
-
"DOCUMENTATION": documentation_text,
|
108
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
109
|
-
}
|
110
|
-
)
|
111
|
-
else:
|
112
|
-
return chain.invoke(
|
113
|
-
{
|
114
|
-
"SOURCE_CODE": input,
|
115
|
-
"DIAGRAM_TYPE": self._diagram_type,
|
116
|
-
}
|
117
|
-
)
|
118
|
-
except OutputParserException:
|
119
|
-
pass
|
120
|
-
|
121
|
-
raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
|
122
|
-
|
123
|
-
@run_if_changed(
|
124
|
-
"_diagram_prompt_template_name",
|
125
|
-
"_source_language",
|
126
|
-
)
|
127
|
-
def _load_diagram_prompt_engine(self) -> None:
|
128
|
-
"""Load the prompt engine according to this instance's attributes.
|
129
|
-
|
130
|
-
If the relevant fields have not been changed since the last time this method was
|
131
|
-
called, nothing happens.
|
132
|
-
"""
|
133
|
-
self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
|
134
|
-
source_language=self._source_language,
|
135
|
-
target_language="text",
|
136
|
-
target_version=None,
|
137
|
-
prompt_template=self._diagram_prompt_template_name,
|
48
|
+
return RunnableParallel(
|
49
|
+
SOURCE_CODE=self._parser.parse_input,
|
50
|
+
context=self._retriever,
|
138
51
|
)
|
139
|
-
self._diagram_prompt = self._diagram_prompt_engine.prompt
|
janus/converter/translate.py
CHANGED
@@ -90,7 +90,7 @@ class Translator(Converter):
|
|
90
90
|
f"({self._source_language} != {self._target_language})"
|
91
91
|
)
|
92
92
|
|
93
|
-
prompt_engine = MODEL_PROMPT_ENGINES[self.
|
93
|
+
prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
|
94
94
|
source_language=self._source_language,
|
95
95
|
target_language=self._target_language,
|
96
96
|
target_version=self._target_version,
|
@@ -12,7 +12,7 @@ class TestAlcSplitter(unittest.TestCase):
|
|
12
12
|
def setUp(self):
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
|
-
llm
|
15
|
+
llm = load_model(model_name)
|
16
16
|
self.splitter = AlcSplitter(model=llm)
|
17
17
|
self.combiner = Combiner(language="ibmhlasm")
|
18
18
|
self.test_file = Path("janus/language/alc/_tests/alc.asm")
|
janus/language/alc/alc.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1
1
|
import re
|
2
2
|
from typing import Optional
|
3
3
|
|
4
|
-
from langchain.schema.language_model import BaseLanguageModel
|
5
|
-
|
6
4
|
from janus.language.block import CodeBlock
|
7
5
|
from janus.language.combine import Combiner
|
8
6
|
from janus.language.node import NodeType
|
9
7
|
from janus.language.treesitter import TreeSitterSplitter
|
8
|
+
from janus.llm.models_info import JanusModel
|
10
9
|
from janus.utils.logger import create_logger
|
11
10
|
|
12
11
|
log = create_logger(__name__)
|
@@ -27,7 +26,7 @@ class AlcSplitter(TreeSitterSplitter):
|
|
27
26
|
|
28
27
|
def __init__(
|
29
28
|
self,
|
30
|
-
model:
|
29
|
+
model: JanusModel | None = None,
|
31
30
|
max_tokens: int = 4096,
|
32
31
|
protected_node_types: tuple[str, ...] = (),
|
33
32
|
prune_node_types: tuple[str, ...] = (),
|
@@ -101,7 +100,7 @@ class AlcListingSplitter(AlcSplitter):
|
|
101
100
|
|
102
101
|
def __init__(
|
103
102
|
self,
|
104
|
-
model:
|
103
|
+
model: JanusModel | None = None,
|
105
104
|
max_tokens: int = 4096,
|
106
105
|
protected_node_types: tuple[str, ...] = (),
|
107
106
|
prune_node_types: tuple[str, ...] = (),
|
@@ -129,12 +128,18 @@ class AlcListingSplitter(AlcSplitter):
|
|
129
128
|
prune_unprotected=prune_unprotected,
|
130
129
|
)
|
131
130
|
|
132
|
-
def
|
131
|
+
def split_string(self, code: str, name: str) -> CodeBlock:
|
132
|
+
# Override split_string to use processed code and track active usings
|
133
133
|
active_usings = self.get_active_usings(code)
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
134
|
+
processed_code = self.preproccess_assembly(code)
|
135
|
+
root = super().split_string(processed_code, name)
|
136
|
+
if active_usings is not None:
|
137
|
+
stack = [root]
|
138
|
+
while stack:
|
139
|
+
block = stack.pop()
|
140
|
+
block.context_tags["active_usings"] = active_usings
|
141
|
+
stack.extend(block.children)
|
142
|
+
return root
|
138
143
|
|
139
144
|
def preproccess_assembly(self, code: str) -> str:
|
140
145
|
"""Remove non-essential lines from an assembly snippet"""
|
@@ -142,7 +147,7 @@ class AlcListingSplitter(AlcSplitter):
|
|
142
147
|
lines = code.splitlines()
|
143
148
|
lines = self.strip_header_and_left(lines)
|
144
149
|
lines = self.strip_addresses(lines)
|
145
|
-
return "".join(str(line) for line in lines)
|
150
|
+
return "\n".join(str(line) for line in lines)
|
146
151
|
|
147
152
|
def get_active_usings(self, code: str) -> Optional[str]:
|
148
153
|
"""Look for 'active usings' in the ALC listing header"""
|
@@ -15,7 +15,7 @@ class TestBinarySplitter(unittest.TestCase):
|
|
15
15
|
def setUp(self):
|
16
16
|
model_name = "gpt-4o"
|
17
17
|
self.binary_file = Path("janus/language/binary/_tests/hello")
|
18
|
-
self.llm
|
18
|
+
self.llm = load_model(model_name)
|
19
19
|
self.splitter = BinarySplitter(model=self.llm)
|
20
20
|
os.environ["GHIDRA_INSTALL_PATH"] = "~/programs/ghidra_10.4_PUBLIC"
|
21
21
|
|
janus/language/binary/binary.py
CHANGED
@@ -5,11 +5,11 @@ import tempfile
|
|
5
5
|
from pathlib import Path
|
6
6
|
|
7
7
|
import tree_sitter
|
8
|
-
from langchain.schema.language_model import BaseLanguageModel
|
9
8
|
|
10
9
|
from janus.language.block import CodeBlock
|
11
10
|
from janus.language.combine import Combiner
|
12
11
|
from janus.language.treesitter import TreeSitterSplitter
|
12
|
+
from janus.llm.models_info import JanusModel
|
13
13
|
from janus.utils.enums import LANGUAGES
|
14
14
|
from janus.utils.logger import create_logger
|
15
15
|
|
@@ -31,7 +31,7 @@ class BinarySplitter(TreeSitterSplitter):
|
|
31
31
|
|
32
32
|
def __init__(
|
33
33
|
self,
|
34
|
-
model:
|
34
|
+
model: JanusModel | None = None,
|
35
35
|
max_tokens: int = 4096,
|
36
36
|
protected_node_types: tuple[str] = (),
|
37
37
|
prune_node_types: tuple[str] = (),
|
@@ -12,7 +12,7 @@ class TestMumpsSplitter(unittest.TestCase):
|
|
12
12
|
def setUp(self):
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
|
-
llm
|
15
|
+
llm = load_model(model_name)
|
16
16
|
self.splitter = MumpsSplitter(model=llm)
|
17
17
|
self.combiner = Combiner(language="mumps")
|
18
18
|
self.test_file = Path("janus/language/mumps/_tests/mumps.m")
|
janus/language/mumps/mumps.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1
1
|
import re
|
2
2
|
|
3
|
-
from langchain.schema.language_model import BaseLanguageModel
|
4
|
-
|
5
3
|
from janus.language.block import CodeBlock
|
6
4
|
from janus.language.combine import Combiner
|
7
5
|
from janus.language.node import NodeType
|
8
6
|
from janus.language.splitter import Splitter
|
7
|
+
from janus.llm.models_info import JanusModel
|
9
8
|
from janus.utils.logger import create_logger
|
10
9
|
|
11
10
|
log = create_logger(__name__)
|
@@ -44,7 +43,7 @@ class MumpsSplitter(Splitter):
|
|
44
43
|
|
45
44
|
def __init__(
|
46
45
|
self,
|
47
|
-
model:
|
46
|
+
model: JanusModel | None = None,
|
48
47
|
max_tokens: int = 4096,
|
49
48
|
protected_node_types: tuple[str] = ("routine_definition",),
|
50
49
|
prune_node_types: tuple[str] = (),
|
janus/language/splitter.py
CHANGED
@@ -4,11 +4,11 @@ from pathlib import Path
|
|
4
4
|
from typing import List
|
5
5
|
|
6
6
|
import tiktoken
|
7
|
-
from langchain.schema.language_model import BaseLanguageModel
|
8
7
|
|
9
8
|
from janus.language.block import CodeBlock
|
10
9
|
from janus.language.file import FileManager
|
11
10
|
from janus.language.node import NodeType
|
11
|
+
from janus.llm.models_info import JanusModel
|
12
12
|
from janus.utils.logger import create_logger
|
13
13
|
|
14
14
|
log = create_logger(__name__)
|
@@ -44,7 +44,7 @@ class Splitter(FileManager):
|
|
44
44
|
def __init__(
|
45
45
|
self,
|
46
46
|
language: str,
|
47
|
-
model:
|
47
|
+
model: JanusModel | None = None,
|
48
48
|
max_tokens: int = 4096,
|
49
49
|
skip_merge: bool = False,
|
50
50
|
protected_node_types: tuple[str, ...] = (),
|
@@ -13,7 +13,7 @@ class TestTreeSitterSplitter(unittest.TestCase):
|
|
13
13
|
"""Set up the tests."""
|
14
14
|
model_name = "gpt-4o"
|
15
15
|
self.maxDiff = None
|
16
|
-
self.llm
|
16
|
+
self.llm = load_model(model_name)
|
17
17
|
|
18
18
|
def _split(self):
|
19
19
|
"""Split the test file."""
|
@@ -7,10 +7,10 @@ from typing import Optional
|
|
7
7
|
|
8
8
|
import tree_sitter
|
9
9
|
from git import Repo
|
10
|
-
from langchain.schema.language_model import BaseLanguageModel
|
11
10
|
|
12
11
|
from janus.language.block import CodeBlock, NodeType
|
13
12
|
from janus.language.splitter import Splitter
|
13
|
+
from janus.llm.models_info import JanusModel
|
14
14
|
from janus.utils.enums import LANGUAGES
|
15
15
|
from janus.utils.logger import create_logger
|
16
16
|
|
@@ -25,7 +25,7 @@ class TreeSitterSplitter(Splitter):
|
|
25
25
|
def __init__(
|
26
26
|
self,
|
27
27
|
language: str,
|
28
|
-
model:
|
28
|
+
model: JanusModel | None = None,
|
29
29
|
max_tokens: int = 4096,
|
30
30
|
protected_node_types: tuple[str, ...] = (),
|
31
31
|
prune_node_types: tuple[str, ...] = (),
|
janus/llm/model_callbacks.py
CHANGED
@@ -12,6 +12,17 @@ from janus.utils.logger import create_logger
|
|
12
12
|
|
13
13
|
log = create_logger(__name__)
|
14
14
|
|
15
|
+
openai_model_reroutes = {
|
16
|
+
"gpt-4o": "gpt-4o-2024-05-13",
|
17
|
+
"gpt-4o-mini": "gpt-4o-mini",
|
18
|
+
"gpt-4": "gpt-4-0613",
|
19
|
+
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
20
|
+
"gpt-4-turbo-preview": "gpt-4-0125-preview",
|
21
|
+
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
|
22
|
+
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
|
23
|
+
"gpt-3.5-turbo-16k-0613": "gpt-3.5-turbo-0125",
|
24
|
+
}
|
25
|
+
|
15
26
|
|
16
27
|
# Updated 2024-06-21
|
17
28
|
COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
|
@@ -45,6 +56,8 @@ def _get_token_cost(
|
|
45
56
|
prompt_tokens: int, completion_tokens: int, model_id: str | None
|
46
57
|
) -> float:
|
47
58
|
"""Get the cost of tokens according to model ID"""
|
59
|
+
if model_id in openai_model_reroutes:
|
60
|
+
model_id = openai_model_reroutes[model_id]
|
48
61
|
if model_id not in COST_PER_1K_TOKENS:
|
49
62
|
raise ValueError(
|
50
63
|
f"Unknown model: {model_id}. Please provide a valid model name."
|
janus/llm/models_info.py
CHANGED
@@ -2,14 +2,14 @@ import json
|
|
2
2
|
import os
|
3
3
|
import time
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import
|
5
|
+
from typing import Protocol, TypeVar
|
6
6
|
|
7
7
|
from dotenv import load_dotenv
|
8
8
|
from langchain_community.llms import HuggingFaceTextGenInference
|
9
|
-
from langchain_core.
|
9
|
+
from langchain_core.runnables import Runnable
|
10
10
|
from langchain_openai import ChatOpenAI
|
11
11
|
|
12
|
-
from janus.llm.model_callbacks import COST_PER_1K_TOKENS
|
12
|
+
from janus.llm.model_callbacks import COST_PER_1K_TOKENS, openai_model_reroutes
|
13
13
|
from janus.prompts.prompt import (
|
14
14
|
ChatGptPromptEngine,
|
15
15
|
ClaudePromptEngine,
|
@@ -44,17 +44,34 @@ except ImportError:
|
|
44
44
|
)
|
45
45
|
|
46
46
|
|
47
|
+
ModelType = TypeVar(
|
48
|
+
"ModelType",
|
49
|
+
ChatOpenAI,
|
50
|
+
HuggingFaceTextGenInference,
|
51
|
+
Bedrock,
|
52
|
+
BedrockChat,
|
53
|
+
HuggingFacePipeline,
|
54
|
+
)
|
55
|
+
|
56
|
+
|
57
|
+
class JanusModelProtocol(Protocol):
|
58
|
+
model_id: str
|
59
|
+
model_type_name: str
|
60
|
+
token_limit: int
|
61
|
+
input_token_cost: float
|
62
|
+
output_token_cost: float
|
63
|
+
prompt_engine: type[PromptEngine]
|
64
|
+
|
65
|
+
def get_num_tokens(self, text: str) -> int:
|
66
|
+
...
|
67
|
+
|
68
|
+
|
69
|
+
class JanusModel(Runnable, JanusModelProtocol):
|
70
|
+
...
|
71
|
+
|
72
|
+
|
47
73
|
load_dotenv()
|
48
74
|
|
49
|
-
openai_model_reroutes = {
|
50
|
-
"gpt-4o": "gpt-4o-2024-05-13",
|
51
|
-
"gpt-4o-mini": "gpt-4o-mini",
|
52
|
-
"gpt-4": "gpt-4-0613",
|
53
|
-
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
54
|
-
"gpt-4-turbo-preview": "gpt-4-0125-preview",
|
55
|
-
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
|
56
|
-
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
|
57
|
-
}
|
58
75
|
|
59
76
|
openai_models = [
|
60
77
|
"gpt-4o",
|
@@ -105,24 +122,15 @@ bedrock_models = [
|
|
105
122
|
]
|
106
123
|
all_models = [*openai_models, *bedrock_models]
|
107
124
|
|
108
|
-
MODEL_TYPE_CONSTRUCTORS: dict[str,
|
125
|
+
MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
|
109
126
|
"OpenAI": ChatOpenAI,
|
110
127
|
"HuggingFace": HuggingFaceTextGenInference,
|
128
|
+
"Bedrock": Bedrock,
|
129
|
+
"BedrockChat": BedrockChat,
|
130
|
+
"HuggingFaceLocal": HuggingFacePipeline,
|
111
131
|
}
|
112
132
|
|
113
|
-
|
114
|
-
MODEL_TYPE_CONSTRUCTORS.update(
|
115
|
-
{
|
116
|
-
"HuggingFaceLocal": HuggingFacePipeline.from_model_id,
|
117
|
-
"Bedrock": Bedrock,
|
118
|
-
"BedrockChat": BedrockChat,
|
119
|
-
}
|
120
|
-
)
|
121
|
-
except NameError:
|
122
|
-
pass
|
123
|
-
|
124
|
-
|
125
|
-
MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
133
|
+
MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
|
126
134
|
**{m: ChatGptPromptEngine for m in openai_models},
|
127
135
|
**{m: ClaudePromptEngine for m in claude_models},
|
128
136
|
**{m: Llama2PromptEngine for m in llama2_models},
|
@@ -132,11 +140,6 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
|
132
140
|
**{m: MistralPromptEngine for m in mistral_models},
|
133
141
|
}
|
134
142
|
|
135
|
-
_open_ai_defaults: dict[str, str] = {
|
136
|
-
"openai_api_key": os.getenv("OPENAI_API_KEY"),
|
137
|
-
"openai_organization": os.getenv("OPENAI_ORG_ID"),
|
138
|
-
}
|
139
|
-
|
140
143
|
MODEL_ID_TO_LONG_ID = {
|
141
144
|
**{m: mr for m, mr in openai_model_reroutes.items()},
|
142
145
|
"bedrock-claude-v2": "anthropic.claude-v2",
|
@@ -168,7 +171,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
|
|
168
171
|
|
169
172
|
MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
|
170
173
|
|
171
|
-
MODEL_TYPES: dict[str,
|
174
|
+
MODEL_TYPES: dict[str, str] = {
|
172
175
|
**{m: "OpenAI" for m in openai_models},
|
173
176
|
**{m: "BedrockChat" for m in bedrock_models},
|
174
177
|
}
|
@@ -211,53 +214,90 @@ def get_available_model_names() -> list[str]:
|
|
211
214
|
return avaialable_models
|
212
215
|
|
213
216
|
|
214
|
-
def load_model(
|
215
|
-
user_model_name: str,
|
216
|
-
) -> tuple[BaseLanguageModel, str, int, dict[str, float]]:
|
217
|
+
def load_model(model_id) -> JanusModel:
|
217
218
|
if not MODEL_CONFIG_DIR.exists():
|
218
219
|
MODEL_CONFIG_DIR.mkdir(parents=True)
|
219
|
-
model_config_file = MODEL_CONFIG_DIR / f"{
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
f"default models for {user_model_name}."
|
224
|
-
)
|
225
|
-
model_id = user_model_name
|
226
|
-
if user_model_name not in DEFAULT_MODELS:
|
227
|
-
message = (
|
228
|
-
f"Model {user_model_name} not found in default models. Make sure to run "
|
229
|
-
"`janus llm add` first."
|
230
|
-
)
|
231
|
-
log.error(message)
|
232
|
-
raise ValueError(message)
|
233
|
-
model_config = {
|
234
|
-
"model_type": MODEL_TYPES[model_id],
|
235
|
-
"model_id": model_id,
|
236
|
-
"model_args": MODEL_DEFAULT_ARGUMENTS[model_id],
|
237
|
-
"token_limit": TOKEN_LIMITS.get(MODEL_ID_TO_LONG_ID[model_id], 4096),
|
238
|
-
"model_cost": COST_PER_1K_TOKENS.get(
|
239
|
-
MODEL_ID_TO_LONG_ID[model_id], {"input": 0, "output": 0}
|
240
|
-
),
|
241
|
-
}
|
242
|
-
with open(model_config_file, "w") as f:
|
243
|
-
json.dump(model_config, f)
|
244
|
-
else:
|
220
|
+
model_config_file = MODEL_CONFIG_DIR / f"{model_id}.json"
|
221
|
+
|
222
|
+
if model_config_file.exists():
|
223
|
+
log.info(f"Loading {model_id} from {model_config_file}.")
|
245
224
|
with open(model_config_file, "r") as f:
|
246
225
|
model_config = json.load(f)
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
226
|
+
model_type_name = model_config["model_type"]
|
227
|
+
model_id = model_config["model_id"]
|
228
|
+
model_args = model_config["model_args"]
|
229
|
+
token_limit = model_config["token_limit"]
|
230
|
+
input_token_cost = model_config["model_cost"]["input"]
|
231
|
+
output_token_cost = model_config["model_cost"]["output"]
|
232
|
+
|
233
|
+
elif model_id in DEFAULT_MODELS:
|
234
|
+
model_id = model_id
|
235
|
+
model_long_id = MODEL_ID_TO_LONG_ID[model_id]
|
236
|
+
model_type_name = MODEL_TYPES[model_id]
|
237
|
+
model_args = MODEL_DEFAULT_ARGUMENTS[model_id]
|
238
|
+
|
239
|
+
token_limit = 0
|
240
|
+
input_token_cost = 0.0
|
241
|
+
output_token_cost = 0.0
|
242
|
+
if model_long_id in TOKEN_LIMITS:
|
243
|
+
token_limit = TOKEN_LIMITS[model_long_id]
|
244
|
+
if model_long_id in COST_PER_1K_TOKENS:
|
245
|
+
token_limits = COST_PER_1K_TOKENS[model_long_id]
|
246
|
+
input_token_cost = token_limits["input"]
|
247
|
+
output_token_cost = token_limits["output"]
|
248
|
+
|
249
|
+
else:
|
250
|
+
model_list = "\n\t".join(DEFAULT_MODELS)
|
251
|
+
message = (
|
252
|
+
f"Model {model_id} not found in user-defined model directory "
|
253
|
+
f"({MODEL_CONFIG_DIR}), and is not a default model. Valid default "
|
254
|
+
f"models:\n\t{model_list}\n"
|
255
|
+
f"To use a custom model, first run `janus llm add`."
|
256
|
+
)
|
257
|
+
log.error(message)
|
258
|
+
raise ValueError(message)
|
259
|
+
|
260
|
+
if model_type_name == "HuggingFaceLocal":
|
261
|
+
model = HuggingFacePipeline.from_model_id(
|
262
|
+
model_id=model_id,
|
263
|
+
task="text-generation",
|
264
|
+
model_kwargs=model_args,
|
265
|
+
)
|
266
|
+
model_args.update(pipeline=model.pipeline)
|
267
|
+
|
268
|
+
elif model_type_name == "OpenAI":
|
269
|
+
model_args.update(
|
270
|
+
openai_api_key=str(os.getenv("OPENAI_API_KEY")),
|
271
|
+
openai_organization=str(os.getenv("OPENAI_ORG_ID")),
|
272
|
+
)
|
251
273
|
log.warning("Do NOT use this model in sensitive environments!")
|
252
274
|
log.warning("If you would like to cancel, please press Ctrl+C.")
|
253
275
|
log.warning("Waiting 10 seconds...")
|
254
276
|
# Give enough time for the user to read the warnings and cancel
|
255
277
|
time.sleep(10)
|
256
278
|
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
279
|
+
model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
|
280
|
+
prompt_engine = MODEL_PROMPT_ENGINES[model_id]
|
281
|
+
|
282
|
+
class JanusModel(model_type):
|
283
|
+
model_id: str
|
284
|
+
short_model_id: str
|
285
|
+
model_type_name: str
|
286
|
+
token_limit: int
|
287
|
+
input_token_cost: float
|
288
|
+
output_token_cost: float
|
289
|
+
prompt_engine: type[PromptEngine]
|
290
|
+
|
291
|
+
model_args.update(
|
292
|
+
model_id=MODEL_ID_TO_LONG_ID[model_id],
|
293
|
+
short_model_id=model_id,
|
294
|
+
)
|
295
|
+
|
296
|
+
return JanusModel(
|
297
|
+
model_type_name=model_type_name,
|
298
|
+
token_limit=token_limit,
|
299
|
+
input_token_cost=input_token_cost,
|
300
|
+
output_token_cost=output_token_cost,
|
301
|
+
prompt_engine=prompt_engine,
|
302
|
+
**model_args,
|
263
303
|
)
|