janus-llm 2.1.0__py3-none-any.whl → 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +2 -2
- janus/__main__.py +1 -1
- janus/_tests/test_cli.py +1 -2
- janus/cli.py +43 -50
- janus/converter/__init__.py +6 -0
- janus/converter/_tests/__init__.py +0 -0
- janus/{_tests → converter/_tests}/test_translate.py +11 -22
- janus/converter/converter.py +614 -0
- janus/converter/diagram.py +124 -0
- janus/converter/document.py +131 -0
- janus/converter/evaluate.py +15 -0
- janus/converter/requirements.py +51 -0
- janus/converter/translate.py +108 -0
- janus/language/block.py +1 -1
- janus/language/combine.py +0 -1
- janus/language/treesitter/treesitter.py +20 -1
- janus/llm/model_callbacks.py +33 -36
- janus/llm/models_info.py +14 -0
- janus/metrics/reading.py +27 -5
- janus/prompts/prompt.py +37 -11
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/METADATA +1 -1
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/RECORD +25 -19
- janus/converter.py +0 -161
- janus/translate.py +0 -987
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/LICENSE +0 -0
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/WHEEL +0 -0
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,124 @@
|
|
1
|
+
import json
|
2
|
+
from copy import deepcopy
|
3
|
+
|
4
|
+
from janus.converter.converter import run_if_changed
|
5
|
+
from janus.converter.document import Documenter
|
6
|
+
from janus.language.block import TranslatedCodeBlock
|
7
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
8
|
+
from janus.utils.logger import create_logger
|
9
|
+
|
10
|
+
log = create_logger(__name__)
|
11
|
+
|
12
|
+
|
13
|
+
class DiagramGenerator(Documenter):
    """DiagramGenerator

    A class that uses an LLM to turn source code into a (PlantUML) diagram
    description, optionally informed by LLM-generated documentation.
    """

    def __init__(
        self,
        diagram_type="Activity",
        add_documentation=False,
        **kwargs,
    ) -> None:
        """Initialize the DiagramGenerator class

        Arguments:
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
            source_language: The source programming language.
            max_prompts: The maximum number of prompts to try before giving up.
            db_path: path to chroma database
            db_config: database configuration
            diagram_type: type of PLANTUML diagram to generate; passed to the
                prompt as DIAGRAM_TYPE.
            add_documentation: if True, each block is first documented with the
                inherited "document" prompt and that documentation is passed to
                the "diagram_with_documentation" prompt as DOCUMENTATION.
        """
        super().__init__(**kwargs)
        self._diagram_type = diagram_type
        self._add_documentation = add_documentation
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self._documenter = None
        if add_documentation:
            self._diagram_prompt_template_name = "diagram_with_documentation"
        else:
            self._diagram_prompt_template_name = "diagram"
        self._load_diagram_prompt_engine()

    def _generate_documentation(self, block: TranslatedCodeBlock) -> str:
        """Document a copy of `block` with the inherited prompt and return its docstring.

        The copy keeps `block` itself untouched. The documentation parser's JSON
        output is expected to contain a "docstring" key.

        Raises:
            ValueError: If the documentation pass did not translate the block.
        """
        documentation_block = deepcopy(block)
        super()._add_translation(documentation_block)
        if not documentation_block.translated:
            message = "Error: unable to produce documentation for code block"
            log.error(message)
            raise ValueError(message)
        return json.loads(documentation_block.text)["docstring"]

    def _add_translation(self, block: TranslatedCodeBlock) -> None:
        """Given an "empty" `TranslatedCodeBlock`, translate the code represented in
        `block.original`, setting the relevant fields in the translated block. The
        `TranslatedCodeBlock` is updated in-place, nothing is returned. Note that this
        translates *only* the code for this block, not its children.

        Arguments:
            block: An empty `TranslatedCodeBlock`

        Raises:
            ValueError: If no LLM is configured, or (with documentation enabled)
                if the documentation pass fails.
        """
        if block.translated:
            return

        if block.original.text is None:
            block.translated = True
            return

        # Check the model before the (optional) documentation pass so that a
        # misconfigured instance fails without spending an LLM request.
        if self._llm is None:
            message = (
                "Model not configured correctly, cannot translate. Try setting "
                "the model"
            )
            log.error(message)
            raise ValueError(message)

        inputs = {
            "SOURCE_CODE": block.original.text,
            "DIAGRAM_TYPE": self._diagram_type,
        }
        if self._add_documentation:
            inputs["DOCUMENTATION"] = self._generate_documentation(block)

        log.debug(f"[{block.name}] Translating...")
        log.debug(f"[{block.name}] Input text:\n{block.original.text}")

        # The documentation pass may have re-pointed the shared parser at the
        # copied block, so (re)set the reference right before the diagram call.
        self._parser.set_reference(block.original)

        query_and_parse = self.diagram_prompt | self._llm | self._parser
        block.text = query_and_parse.invoke(inputs)
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True

        log.debug(f"[{block.name}] Output code:\n{block.text}")

    @run_if_changed(
        "_diagram_prompt_template_name",
        "_source_language",
        "_model_name",
    )
    def _load_diagram_prompt_engine(self) -> None:
        """Load the diagram prompt engine according to this instance's attributes.

        The engine depends on the model name as well as the template and source
        language, so all three are tracked. If none of them has changed since
        the last time this method was called, nothing happens.
        """
        self._diagram_prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
            source_language=self._source_language,
            target_language="text",
            target_version=None,
            prompt_template=self._diagram_prompt_template_name,
        )
        self.diagram_prompt = self._diagram_prompt_engine.prompt
|
@@ -0,0 +1,131 @@
|
|
1
|
+
import json
|
2
|
+
import re
|
3
|
+
from copy import deepcopy
|
4
|
+
|
5
|
+
from janus.converter.converter import Converter
|
6
|
+
from janus.language.block import TranslatedCodeBlock
|
7
|
+
from janus.language.combine import JsonCombiner
|
8
|
+
from janus.parsers.doc_parser import (
|
9
|
+
MadlibsDocumentationParser,
|
10
|
+
MultiDocumentationParser,
|
11
|
+
)
|
12
|
+
from janus.utils.enums import LANGUAGES
|
13
|
+
from janus.utils.logger import create_logger
|
14
|
+
|
15
|
+
log = create_logger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class Documenter(Converter):
    """Converter that asks the LLM to document source code ("document" prompt)."""

    def __init__(
        self, source_language: str = "fortran", drop_comments: bool = True, **kwargs
    ):
        """Initialize a Documenter.

        Arguments:
            source_language: The source programming language; a key of
                `janus.utils.enums.LANGUAGES`. The comment-node lookup below
                lowercases it, matching how `Translator` normalizes language
                names before indexing `LANGUAGES`.
            drop_comments: If True, register the language's comment node type
                with `set_prune_node_types` so those nodes are pruned.
            **kwargs: Forwarded to `Converter.__init__`.
        """
        kwargs.update(source_language=source_language)
        super().__init__(**kwargs)
        self.set_prompt("document")

        if drop_comments:
            # Languages may define their own comment node name; otherwise fall
            # back to "comment".
            comment_node_type = LANGUAGES[source_language.lower()].get(
                "comment_node_type", "comment"
            )
            self.set_prune_node_types((comment_node_type,))

        self._load_parameters()
|
33
|
+
|
34
|
+
|
35
|
+
class MultiDocumenter(Documenter):
    """Documenter variant driven by the "multidocument" prompt.

    LLM responses are parsed with `MultiDocumentationParser`, and per-block
    results are merged with `JsonCombiner`.
    """

    def __init__(self, **kwargs) -> None:
        """Initialize a MultiDocumenter; keyword arguments go to `Documenter`."""
        super().__init__(**kwargs)
        # Replace the base "document" prompt and the default parsing pipeline.
        self.set_prompt("multidocument")
        self._combiner = JsonCombiner()
        self._parser = MultiDocumentationParser()
|
41
|
+
|
42
|
+
|
43
|
+
class MadLibsDocumenter(Documenter):
    """Documenter that fills in comment placeholders ("madlibs") in source code.

    Comments are kept in the source (``drop_comments=False``). The
    "document_madlibs" prompt is used, responses are parsed with
    `MadlibsDocumentationParser`, and results are merged with `JsonCombiner`.
    When ``comments_per_request`` is set, a block holding more comment
    placeholders than that is split across several LLM requests.
    """

    # Placeholder tokens such as "<INLINE_COMMENT ab12cd34>" found in the
    # block text (presumably standing in for the original comments).
    _COMMENT_PATTERN = re.compile(r"<(?:INLINE|BLOCK)_COMMENT \w{8}>")

    def __init__(
        self,
        comments_per_request: int | None = None,
        **kwargs,
    ) -> None:
        """Initialize a MadLibsDocumenter.

        Arguments:
            comments_per_request: Maximum number of comment placeholders sent
                in one request, or None to always send the whole block at once.
            **kwargs: Forwarded to `Documenter.__init__`; ``drop_comments`` is
                always forced to False.

        Raises:
            ValueError: If ``comments_per_request`` is not None and is less
                than 1. A zero group size would make ``range`` fail during
                translation, and a negative one would silently produce no output.
        """
        if comments_per_request is not None and comments_per_request < 1:
            raise ValueError(
                "comments_per_request must be a positive integer or None, got "
                f"{comments_per_request}"
            )
        kwargs.update(drop_comments=False)
        super().__init__(**kwargs)
        self.set_prompt("document_madlibs")
        self._combiner = JsonCombiner()
        self._parser = MadlibsDocumentationParser()

        self.comments_per_request = comments_per_request

    def _add_translation(self, block: TranslatedCodeBlock):
        """Fill in the comment placeholders of `block`, updating it in place.

        If ``comments_per_request`` is None, the base implementation handles the
        whole block. Otherwise:

        - a block with no placeholders is marked translated and complete with
          ``text = None``;
        - a block with at most ``comments_per_request`` placeholders is handled
          by the base implementation in one request;
        - a larger block is processed in groups. Each request sees the full
          text with only that group's placeholders kept. The parsed JSON
          results are merged, and retries, cost and processing time are
          summed across requests.

        Arguments:
            block: The `TranslatedCodeBlock` to fill in.
        """
        if block.translated:
            return

        if block.original.text is None:
            block.translated = True
            return

        if self.comments_per_request is None:
            return super()._add_translation(block)

        text = block.original.text
        comments = list(self._COMMENT_PATTERN.finditer(text))

        if not comments:
            log.info(f"[{block.name}] Skipping commentless block")
            block.translated = True
            block.text = None
            block.complete = True
            return

        if len(comments) <= self.comments_per_request:
            return super()._add_translation(block)

        group_starts = range(0, len(comments), self.comments_per_request)
        log.debug(
            f"[{block.name}] Block contains more than {self.comments_per_request}"
            f" comments, splitting {len(comments)} comments into"
            f" {len(group_starts)} groups"
        )

        block.processing_time = 0
        block.cost = 0
        block.retries = 0
        obj = {}
        for i in group_starts:
            # Split the text into the section containing this group's comments,
            # all the text prior to those comments, and all the text after them
            working_comments = comments[i : i + self.comments_per_request]
            start_idx = working_comments[0].start()
            end_idx = working_comments[-1].end()

            # Strip all comment placeholders outside of the section of interest
            prefix = self._COMMENT_PATTERN.sub("", text[:start_idx])
            keeper = text[start_idx:end_idx]
            suffix = self._COMMENT_PATTERN.sub("", text[end_idx:])

            # Build a new TranslatedBlock using the new working text
            working_copy = deepcopy(block.original)
            working_copy.text = prefix + keeper + suffix
            working_block = TranslatedCodeBlock(working_copy, self._target_language)

            # Run the LLM on the working text
            super()._add_translation(working_block)

            # Accumulate metadata across all runs
            block.retries += working_block.retries
            block.cost += working_block.cost
            block.processing_time += working_block.processing_time

            # Merge this section's parsed output into the combined result
            out_text = self._parser.parse(working_block.text)
            obj.update(json.loads(out_text))

        # Re-validate the merged result against the full original block
        self._parser.set_reference(block.original)
        block.text = self._parser.parse(json.dumps(obj))
        block.tokens = self._llm.get_num_tokens(block.text)
        block.translated = True
|
@@ -0,0 +1,15 @@
|
|
1
|
+
from janus.converter.converter import Converter
|
2
|
+
from janus.language.combine import JsonCombiner
|
3
|
+
from janus.parsers.eval_parser import EvaluationParser
|
4
|
+
from janus.utils.logger import create_logger
|
5
|
+
|
6
|
+
log = create_logger(__name__)
|
7
|
+
|
8
|
+
|
9
|
+
class Evaluator(Converter):
    """Converter that runs the "evaluate" prompt over source code.

    Responses are parsed by `EvaluationParser` and combined with `JsonCombiner`.
    """

    def __init__(self, **kwargs):
        """Initialize an Evaluator; every keyword argument goes to `Converter`."""
        super().__init__(**kwargs)
        self.set_prompt("evaluate")
        self._combiner = JsonCombiner()
        self._parser = EvaluationParser()
        # Reload the converter's parameters now that the evaluation prompt,
        # combiner and parser are set.
        self._load_parameters()
|
@@ -0,0 +1,51 @@
|
|
1
|
+
import json
|
2
|
+
from pathlib import Path
|
3
|
+
|
4
|
+
from janus.converter.document import Documenter
|
5
|
+
from janus.language.block import TranslatedCodeBlock
|
6
|
+
from janus.language.combine import ChunkCombiner
|
7
|
+
from janus.parsers.reqs_parser import RequirementsParser
|
8
|
+
from janus.utils.logger import create_logger
|
9
|
+
|
10
|
+
log = create_logger(__name__)
|
11
|
+
|
12
|
+
|
13
|
+
class RequirementsDocumenter(Documenter):
    """RequirementsDocumenter

    A class that uses an LLM to derive requirements from source code
    ("requirements" prompt, `RequirementsParser`, `ChunkCombiner`).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.set_prompt("requirements")
        self._combiner = ChunkCombiner()
        self._parser = RequirementsParser()

    def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
        """Save the generated requirements for `block` to a JSON file.

        The file holds ``{"output": [...]}``. Each entry has the chunk's
        ``metadata`` (retries, cost, processing_time), its original ``code``,
        and the parsed ``requirements``.

        Arguments:
            block: The translated block to save. If it has children, one entry
                is written per child. Otherwise a single entry is written for
                the block itself.
            out_path: Destination file; missing parent directories are created.
        """
        output_list = list()
        # For each chunk of code, get generation metadata, the text of the code,
        # and the LLM generated requirements
        chunks = list(block.children) or [block]
        for chunk in chunks:
            code = chunk.original.text
            requirements = self._parser.parse_combined_output(chunk.complete_text)
            metadata = dict(
                retries=chunk.total_retries,
                cost=chunk.total_cost,
                processing_time=chunk.processing_time,
            )
            output_list.append(
                dict(metadata=metadata, code=code, requirements=requirements)
            )
        # Put them all in a top level 'output' key
        obj = dict(
            output=output_list,
        )
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
|
@@ -0,0 +1,108 @@
|
|
1
|
+
from janus.converter.converter import Converter, run_if_changed
|
2
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
3
|
+
from janus.parsers.code_parser import CodeParser
|
4
|
+
from janus.prompts.prompt import SAME_OUTPUT
|
5
|
+
from janus.utils.enums import LANGUAGES
|
6
|
+
from janus.utils.logger import create_logger
|
7
|
+
|
8
|
+
log = create_logger(__name__)
|
9
|
+
|
10
|
+
|
11
|
+
class Translator(Converter):
    """Converter that rewrites source code in another programming language."""

    def __init__(
        self,
        target_language: str = "python",
        target_version: str | None = "3.10",
        **kwargs,
    ) -> None:
        """Create a Translator.

        Arguments:
            model: The LLM to use for translation. If an OpenAI model, the
                `OPENAI_API_KEY` environment variable must be set and the
                `OPENAI_ORG_ID` environment variable should be set if needed.
            model_arguments: Additional arguments to pass to the LLM constructor.
            source_language: The source programming language.
            target_language: The target programming language (case-insensitive
                key of `janus.utils.enums.LANGUAGES`).
            target_version: The target version of the target programming
                language, or None.
            max_prompts: The maximum number of prompts to try before giving up.
            max_tokens: The maximum number of tokens the model will take in.
                If unspecified, the model's default max will be used.
            prompt_template: name of prompt template directory
                (see janus/prompts/templates) or path to a directory.

        Everything except `target_language`/`target_version` is forwarded to
        `Converter.__init__` through ``**kwargs``.
        """
        super().__init__(**kwargs)

        # Declared here; the value is assigned by set_target_language below.
        self._target_version: str | None

        self.set_target_language(
            target_language=target_language,
            target_version=target_version,
        )

        self._load_parameters()

    def _load_parameters(self) -> None:
        """Refresh the code parser first, then the base converter's parameters."""
        self._load_parser()
        super()._load_parameters()

    def set_target_language(
        self, target_language: str, target_version: str | None
    ) -> None:
        """Validate and store the target language and version.

        Dependent objects (prompt, parser) are not rebuilt here; that happens
        the next time translate() runs.

        Arguments:
            target_language: The target programming language (case-insensitive).
            target_version: The target version of the target programming language.

        Raises:
            ValueError: If the language is not a key of `LANGUAGES`.
        """
        normalized = target_language.lower()
        if normalized not in LANGUAGES:
            raise ValueError(
                f"Invalid target language: {normalized}. "
                "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
            )
        self._target_language = normalized
        self._target_version = target_version
        self._target_suffix = f".{LANGUAGES[normalized]['suffix']}"

    @run_if_changed(
        "_prompt_template_name",
        "_source_language",
        "_target_language",
        "_target_version",
        "_model_name",
    )
    def _load_prompt(self) -> None:
        """Build `self._prompt` from the current model, languages and template.

        Skipped when none of the tracked attributes changed since the last call.

        Raises:
            ValueError: If the template is one that requires identical source
                and target languages (`SAME_OUTPUT`) but they differ.
        """
        languages_must_match = self._prompt_template_name in SAME_OUTPUT
        if languages_must_match and self._target_language != self._source_language:
            raise ValueError(
                f"Prompt template ({self._prompt_template_name}) suggests "
                "source and target languages should match, but do not "
                f"({self._source_language} != {self._target_language})"
            )

        engine_factory = MODEL_PROMPT_ENGINES[self._model_name]
        self._prompt = engine_factory(
            source_language=self._source_language,
            target_language=self._target_language,
            target_version=self._target_version,
            prompt_template=self._prompt_template_name,
        ).prompt

    @run_if_changed("_target_language")
    def _load_parser(self) -> None:
        """Build a `CodeParser` for the target language.

        Skipped when the target language has not changed since the last call.
        """
        self._parser = CodeParser(language=self._target_language)
|
janus/language/block.py
CHANGED
janus/language/combine.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import os
|
2
2
|
import platform
|
3
3
|
from collections import defaultdict
|
4
|
+
from ctypes import c_void_p, cdll
|
4
5
|
from pathlib import Path
|
5
6
|
from typing import Optional
|
6
7
|
|
@@ -138,7 +139,25 @@ class TreeSitterSplitter(Splitter):
|
|
138
139
|
|
139
140
|
# Load the parser using the generated .so file
|
140
141
|
self.parser: tree_sitter.Parser = tree_sitter.Parser()
|
141
|
-
self.
|
142
|
+
pointer = self._so_to_pointer(so_file)
|
143
|
+
self.parser.set_language(tree_sitter.Language(pointer, self.language))
|
144
|
+
|
145
|
+
def _so_to_pointer(self, so_file: Path | str) -> int:
    """Load a compiled tree-sitter grammar and return its language pointer.

    Taken from `tree_sitter.Language.__init__` so that a pointer (rather than a
    path) can be passed to `tree_sitter.Language`, getting past its
    deprecation warning.

    Arguments:
        so_file: Path to the shared library (.so) built for `self.language`.
            Accepts `str` or `Path`, matching `_create_parser`; `os.fspath`
            handles both.

    Returns:
        The address returned by the library's `tree_sitter_<language>` function.

    Raises:
        OSError: If the shared library cannot be loaded.
        AttributeError: If the library does not export `tree_sitter_<language>`.
    """
    lib = cdll.LoadLibrary(os.fspath(so_file))
    language_function = getattr(lib, f"tree_sitter_{self.language}")
    # ctypes assumes a C int return type unless told otherwise, which would
    # truncate a 64-bit pointer; declare it as a void pointer instead.
    language_function.restype = c_void_p
    return language_function()
|
142
161
|
|
143
162
|
def _create_parser(self, so_file: Path | str) -> None:
|
144
163
|
"""Create the parser for the given language.
|
janus/llm/model_callbacks.py
CHANGED
@@ -35,6 +35,9 @@ COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
|
|
35
35
|
"ai21.j2-mid-v1": {"input": 0.0125, "output": 0.0125},
|
36
36
|
"ai21.j2-ultra-v1": {"input": 0.0188, "output": 0.0188},
|
37
37
|
"cohere.command-r-plus-v1:0": {"input": 0.003, "output": 0.015},
|
38
|
+
"mistral.mistral-7b-instruct-v0:2": {"input": 0.00015, "output": 0.0002},
|
39
|
+
"mistral.mixtral-8x7b-instruct-v0:1": {"input": 0.00045, "output": 0.0007},
|
40
|
+
"mistral.mistral-large-2402-v1:0": {"input": 0.004, "output": 0.012},
|
38
41
|
}
|
39
42
|
|
40
43
|
|
@@ -103,53 +106,47 @@ class TokenUsageCallbackHandler(BaseCallbackHandler):
|
|
103
106
|
generation = response.generations[0][0]
|
104
107
|
except IndexError:
|
105
108
|
generation = None
|
106
|
-
if isinstance(generation, ChatGeneration):
|
107
|
-
try:
|
108
|
-
message = generation.message
|
109
|
-
if isinstance(message, AIMessage):
|
110
|
-
usage_metadata = message.usage_metadata
|
111
|
-
else:
|
112
|
-
usage_metadata = None
|
113
|
-
except AttributeError:
|
114
|
-
usage_metadata = None
|
115
|
-
else:
|
116
|
-
usage_metadata = None
|
117
|
-
if usage_metadata:
|
118
|
-
token_usage = {"total_tokens": usage_metadata["total_tokens"]}
|
119
|
-
completion_tokens = usage_metadata["output_tokens"]
|
120
|
-
prompt_tokens = usage_metadata["input_tokens"]
|
121
|
-
if response.llm_output is None:
|
122
|
-
# model name (and therefore cost) is unavailable in
|
123
|
-
# streaming responses
|
124
|
-
model_name = ""
|
125
|
-
else:
|
126
|
-
model_name = response.llm_output.get("model_name", "")
|
127
109
|
|
110
|
+
model_id = ""
|
111
|
+
usage_metadata = None
|
112
|
+
if hasattr(response, "llm_output") and response.llm_output is not None:
|
113
|
+
model_id = response.llm_output.get("model_id", model_id)
|
114
|
+
model_id = response.llm_output.get("model_name", model_id)
|
115
|
+
usage_metadata = response.llm_output.get("usage", usage_metadata)
|
116
|
+
usage_metadata = response.llm_output.get("token_usage", usage_metadata)
|
117
|
+
elif isinstance(generation, ChatGeneration):
|
118
|
+
if hasattr(generation, "response_metadata"):
|
119
|
+
model_id = generation.response_metadata.get("model_id", model_id)
|
120
|
+
model_id = generation.response_metadata.get("model_name", model_id)
|
121
|
+
usage_metadata = generation.response_metadata.get("usage", usage_metadata)
|
122
|
+
elif hasattr(generation, "message"):
|
123
|
+
if isinstance(generation.message, AIMessage):
|
124
|
+
usage_metadata = generation.message.usage_metadata
|
125
|
+
|
126
|
+
completion_tokens = 0
|
127
|
+
prompt_tokens = 0
|
128
|
+
total_tokens = 0
|
129
|
+
if usage_metadata:
|
130
|
+
prompt_tokens = usage_metadata.get("prompt_tokens", prompt_tokens)
|
131
|
+
prompt_tokens = usage_metadata.get("input_tokens", prompt_tokens)
|
132
|
+
completion_tokens = usage_metadata.get("completion_tokens", completion_tokens)
|
133
|
+
completion_tokens = usage_metadata.get("output_tokens", completion_tokens)
|
134
|
+
total_tokens = usage_metadata.get("total_tokens", total_tokens)
|
128
135
|
else:
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
if "token_usage" not in response.llm_output:
|
133
|
-
with self._lock:
|
134
|
-
self.successful_requests += 1
|
135
|
-
return None
|
136
|
-
|
137
|
-
# compute tokens and cost for this request
|
138
|
-
token_usage = response.llm_output["token_usage"]
|
139
|
-
completion_tokens = token_usage.get("completion_tokens", 0)
|
140
|
-
prompt_tokens = token_usage.get("prompt_tokens", 0)
|
141
|
-
model_name = response.llm_output.get("model_name", "")
|
136
|
+
with self._lock:
|
137
|
+
self.successful_requests += 1
|
138
|
+
return None
|
142
139
|
|
143
140
|
total_cost = _get_token_cost(
|
144
141
|
prompt_tokens=prompt_tokens,
|
145
142
|
completion_tokens=completion_tokens,
|
146
|
-
model_id=
|
143
|
+
model_id=model_id,
|
147
144
|
)
|
148
145
|
|
149
146
|
# update shared state behind lock
|
150
147
|
with self._lock:
|
151
148
|
self.total_cost += total_cost
|
152
|
-
self.total_tokens +=
|
149
|
+
self.total_tokens += total_tokens
|
153
150
|
self.prompt_tokens += prompt_tokens
|
154
151
|
self.completion_tokens += completion_tokens
|
155
152
|
self.successful_requests += 1
|
janus/llm/models_info.py
CHANGED
@@ -14,6 +14,7 @@ from ..prompts.prompt import (
|
|
14
14
|
CoherePromptEngine,
|
15
15
|
Llama2PromptEngine,
|
16
16
|
Llama3PromptEngine,
|
17
|
+
MistralPromptEngine,
|
17
18
|
PromptEngine,
|
18
19
|
TitanPromptEngine,
|
19
20
|
)
|
@@ -85,12 +86,18 @@ titan_models = [
|
|
85
86
|
cohere_models = [
|
86
87
|
"bedrock-command-r-plus",
|
87
88
|
]
|
89
|
+
mistral_models = [
|
90
|
+
"bedrock-mistral-7b-instruct",
|
91
|
+
"bedrock-mistral-large",
|
92
|
+
"bedrock-mixtral",
|
93
|
+
]
|
88
94
|
bedrock_models = [
|
89
95
|
*claude_models,
|
90
96
|
*llama2_models,
|
91
97
|
*llama3_models,
|
92
98
|
*titan_models,
|
93
99
|
*cohere_models,
|
100
|
+
*mistral_models,
|
94
101
|
]
|
95
102
|
all_models = [*openai_models, *bedrock_models]
|
96
103
|
|
@@ -118,6 +125,7 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
|
118
125
|
**{m: Llama3PromptEngine for m in llama3_models},
|
119
126
|
**{m: TitanPromptEngine for m in titan_models},
|
120
127
|
**{m: CoherePromptEngine for m in cohere_models},
|
128
|
+
**{m: MistralPromptEngine for m in mistral_models},
|
121
129
|
}
|
122
130
|
|
123
131
|
_open_ai_defaults: dict[str, str] = {
|
@@ -142,6 +150,9 @@ model_identifiers = {
|
|
142
150
|
"bedrock-jurassic-2-mid": "ai21.j2-mid-v1",
|
143
151
|
"bedrock-jurassic-2-ultra": "ai21.j2-ultra-v1",
|
144
152
|
"bedrock-command-r-plus": "cohere.command-r-plus-v1:0",
|
153
|
+
"bedrock-mixtral": "mistral.mixtral-8x7b-instruct-v0:1",
|
154
|
+
"bedrock-mistral-7b-instruct": "mistral.mistral-7b-instruct-v0:2",
|
155
|
+
"bedrock-mistral-large": "mistral.mistral-large-2402-v1:0",
|
145
156
|
}
|
146
157
|
|
147
158
|
MODEL_DEFAULT_ARGUMENTS: dict[str, dict[str, str]] = {
|
@@ -182,6 +193,9 @@ TOKEN_LIMITS: dict[str, int] = {
|
|
182
193
|
"ai21.j2-mid-v1": 8192,
|
183
194
|
"ai21.j2-ultra-v1": 8192,
|
184
195
|
"cohere.command-r-plus-v1:0": 128_000,
|
196
|
+
"mistral.mixtral-8x7b-instruct-v0:1": 32_000,
|
197
|
+
"mistral.mistral-7b-instruct-v0:2": 32_000,
|
198
|
+
"mistral.mistral-large-2402-v1:0": 32_000,
|
185
199
|
}
|
186
200
|
|
187
201
|
|