janus-llm 2.0.2__py3-none-any.whl → 3.0.0__py3-none-any.whl
- janus/__init__.py +2 -2
- janus/__main__.py +1 -1
- janus/_tests/test_cli.py +1 -2
- janus/cli.py +43 -51
- janus/converter/__init__.py +6 -0
- janus/converter/_tests/__init__.py +0 -0
- janus/{_tests → converter/_tests}/test_translate.py +11 -22
- janus/converter/converter.py +614 -0
- janus/converter/diagram.py +124 -0
- janus/converter/document.py +131 -0
- janus/converter/evaluate.py +15 -0
- janus/converter/requirements.py +50 -0
- janus/converter/translate.py +108 -0
- janus/embedding/_tests/test_collections.py +2 -2
- janus/language/_tests/test_splitter.py +1 -1
- janus/language/alc/__init__.py +1 -0
- janus/language/alc/_tests/__init__.py +0 -0
- janus/language/alc/_tests/test_alc.py +28 -0
- janus/language/alc/alc.py +87 -0
- janus/language/block.py +4 -2
- janus/language/combine.py +0 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/naive/__init__.py +1 -1
- janus/language/naive/basic_splitter.py +4 -4
- janus/language/naive/chunk_splitter.py +4 -4
- janus/language/naive/registry.py +1 -1
- janus/language/naive/simple_ast.py +23 -12
- janus/language/naive/tag_splitter.py +4 -4
- janus/language/splitter.py +10 -4
- janus/language/treesitter/treesitter.py +26 -8
- janus/llm/model_callbacks.py +34 -37
- janus/llm/models_info.py +16 -3
- janus/metrics/_tests/test_llm.py +2 -3
- janus/metrics/_tests/test_rouge_score.py +1 -1
- janus/metrics/_tests/test_similarity_score.py +1 -1
- janus/metrics/complexity_metrics.py +3 -4
- janus/metrics/metric.py +3 -4
- janus/metrics/reading.py +27 -5
- janus/prompts/prompt.py +67 -7
- janus/utils/enums.py +6 -5
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/METADATA +1 -1
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/RECORD +45 -35
- janus/converter.py +0 -158
- janus/translate.py +0 -981
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/entry_points.txt +0 -0
janus/language/naive/simple_ast.py
CHANGED
@@ -1,18 +1,29 @@
-from
-from
-from
+from ...utils.enums import LANGUAGES
+from ..alc.alc import AlcSplitter
+from ..mumps.mumps import MumpsSplitter
+from ..treesitter import TreeSitterSplitter
+from .registry import register_splitter
 
 
 @register_splitter("ast-flex")
-
-
+def get_flexible_ast(language: str, **kwargs):
+    if language == "ibmhlasm":
+        return AlcSplitter(**kwargs)
+    elif language == "mumps":
+        return MumpsSplitter(**kwargs)
+    else:
+        return TreeSitterSplitter(language=language, **kwargs)
 
 
 @register_splitter("ast-strict")
-
-
-
-
-
-
-
+def get_strict_ast(language: str, **kwargs):
+    kwargs.update(
+        protected_node_types=LANGUAGES[language]["functional_node_types"],
+        prune_unprotected=True,
+    )
+    if language == "ibmhlasm":
+        return AlcSplitter(**kwargs)
+    elif language == "mumps":
+        return MumpsSplitter(**kwargs)
+    else:
+        return TreeSitterSplitter(language=language, **kwargs)
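For context on the `register_splitter` decorator these new factories use: below is a minimal sketch of a decorator-based registry, assuming (as is typical) that janus/language/naive/registry.py keeps a plain dict; the dict name `SPLITTERS` here is hypothetical.

```python
from typing import Callable, Dict

SPLITTERS: Dict[str, Callable] = {}  # hypothetical registry dict

def register_splitter(name: str) -> Callable:
    """Return a decorator that records a splitter factory under a name."""
    def decorator(fn: Callable) -> Callable:
        SPLITTERS[name] = fn
        return fn
    return decorator

@register_splitter("ast-flex")
def get_flexible_ast(language: str, **kwargs):
    ...  # dispatch to AlcSplitter / MumpsSplitter / TreeSitterSplitter

# Callers can then look splitters up by name:
# splitter = SPLITTERS["ast-flex"]("python", max_tokens=4096)
```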
janus/language/naive/tag_splitter.py
CHANGED
@@ -1,7 +1,7 @@
-from
-from
-from
-from
+from ..block import CodeBlock
+from ..node import NodeType
+from ..splitter import Splitter
+from .registry import register_splitter
 
 
 @register_splitter("tag")
janus/language/splitter.py
CHANGED
@@ -47,8 +47,8 @@ class Splitter(FileManager):
         model: None | BaseLanguageModel = None,
         max_tokens: int = 4096,
         skip_merge: bool = False,
-        protected_node_types: tuple[str] = (),
-        prune_node_types: tuple[str] = (),
+        protected_node_types: tuple[str, ...] = (),
+        prune_node_types: tuple[str, ...] = (),
         prune_unprotected: bool = False,
     ):
         """
@@ -340,7 +340,10 @@ class Splitter(FileManager):
         # Double check length (in theory this should never be an issue)
         tokens = self._count_tokens(text)
         if tokens > self.max_tokens:
-            log.error(
+            log.error(
+                f"Merged node ({name}) too long for context!"
+                f" ({tokens} > {self.max_tokens})"
+            )
 
         return CodeBlock(
             text=text,
@@ -420,7 +423,10 @@ class Splitter(FileManager):
            name = f"{node.name}-L#{node_line}"
            tokens = self._count_tokens(line)
            if tokens > self.max_tokens:
-                raise TokenLimitError(
+                raise TokenLimitError(
+                    "Irreducible node too large for context!"
+                    f" ({tokens} > {self.max_tokens})"
+                )
 
            node.children.append(
                CodeBlock(
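A note on the `tuple[str]` → `tuple[str, ...]` change above: `tuple[str]` annotates a 1-tuple containing exactly one string, so the multi-element tuples actually passed in fail static type checks; `tuple[str, ...]` means a tuple of any length. A quick illustration:

```python
def split(protected_node_types: tuple[str, ...] = ()) -> None:
    ...

split(())                        # empty default: fine
split(("function_definition",))  # one element: fine
split(("csect", "dsect"))        # two elements: only valid as tuple[str, ...]
```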
janus/language/treesitter/treesitter.py
CHANGED
@@ -1,6 +1,7 @@
 import os
 import platform
 from collections import defaultdict
+from ctypes import c_void_p, cdll
 from pathlib import Path
 from typing import Optional
 
@@ -26,8 +27,8 @@ class TreeSitterSplitter(Splitter):
         language: str,
         model: None | BaseLanguageModel = None,
         max_tokens: int = 4096,
-        protected_node_types: tuple[str] = (),
-        prune_node_types: tuple[str] = (),
+        protected_node_types: tuple[str, ...] = (),
+        prune_node_types: tuple[str, ...] = (),
         prune_unprotected: bool = False,
     ) -> None:
         """Initialize a TreeSitterSplitter instance.
@@ -48,10 +49,10 @@ class TreeSitterSplitter(Splitter):
         self._load_parser()
 
     def _get_ast(self, code: str) -> CodeBlock:
-
-        tree = self.parser.parse(
+        code_bytes = bytes(code, "utf-8")
+        tree = self.parser.parse(code_bytes)
         root = tree.walk().node
-        root = self._node_to_block(root,
+        root = self._node_to_block(root, code_bytes)
         return root
 
     # Recursively print tree to view parsed output (dev helper function)
@@ -98,7 +99,7 @@ class TreeSitterSplitter(Splitter):
 
         text = node.text.decode()
         children = [self._node_to_block(child, original_text) for child in node.children]
-
+        return CodeBlock(
             id=node.id,
             name=str(node.id),
             text=text,
@@ -112,7 +113,6 @@ class TreeSitterSplitter(Splitter):
             language=self.language,
             tokens=self._count_tokens(text),
         )
-        return node
 
     def _load_parser(self) -> None:
         """Load the parser for the given language.
@@ -139,7 +139,25 @@ class TreeSitterSplitter(Splitter):
 
         # Load the parser using the generated .so file
         self.parser: tree_sitter.Parser = tree_sitter.Parser()
-        self.
+        pointer = self._so_to_pointer(so_file)
+        self.parser.set_language(tree_sitter.Language(pointer, self.language))
+
+    def _so_to_pointer(self, so_file: str) -> int:
+        """Convert the .so file to a pointer.
+
+        Taken from `treesitter.Language.__init__` to get past deprecated warning.
+
+        Arguments:
+            so_file: The path to the so file for the language.
+
+        Returns:
+            The pointer to the language.
+        """
+        lib = cdll.LoadLibrary(os.fspath(so_file))
+        language_function = getattr(lib, f"tree_sitter_{self.language}")
+        language_function.restype = c_void_p
+        pointer = language_function()
+        return pointer
 
     def _create_parser(self, so_file: Path | str) -> None:
         """Create the parser for the given language.
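The `_so_to_pointer` helper above follows the standard ctypes pattern for loading a compiled tree-sitter grammar. A standalone sketch of that pattern (the .so path below is a placeholder):

```python
import os
from ctypes import c_void_p, cdll

def grammar_pointer(so_file: str, language: str) -> int:
    """Load a grammar .so and return the raw language pointer."""
    lib = cdll.LoadLibrary(os.fspath(so_file))
    fn = getattr(lib, f"tree_sitter_{language}")  # exported C constructor symbol
    fn.restype = c_void_p  # the symbol returns an opaque pointer, not int
    return fn()

# pointer = grammar_pointer("build/python.so", "python")  # placeholder path
```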
janus/llm/model_callbacks.py
CHANGED
@@ -8,7 +8,7 @@ from langchain_core.messages import AIMessage
 from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_core.tracers.context import register_configure_hook
 
-from
+from ..utils.logger import create_logger
 
 log = create_logger(__name__)
 
@@ -35,6 +35,9 @@ COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
     "ai21.j2-mid-v1": {"input": 0.0125, "output": 0.0125},
     "ai21.j2-ultra-v1": {"input": 0.0188, "output": 0.0188},
     "cohere.command-r-plus-v1:0": {"input": 0.003, "output": 0.015},
+    "mistral.mistral-7b-instruct-v0:2": {"input": 0.00015, "output": 0.0002},
+    "mistral.mixtral-8x7b-instruct-v0:1": {"input": 0.00045, "output": 0.0007},
+    "mistral.mistral-large-2402-v1:0": {"input": 0.004, "output": 0.012},
 }
 
 
@@ -103,53 +106,47 @@ class TokenUsageCallbackHandler(BaseCallbackHandler):
             generation = response.generations[0][0]
         except IndexError:
             generation = None
-        if isinstance(generation, ChatGeneration):
-            try:
-                message = generation.message
-                if isinstance(message, AIMessage):
-                    usage_metadata = message.usage_metadata
-                else:
-                    usage_metadata = None
-            except AttributeError:
-                usage_metadata = None
-        else:
-            usage_metadata = None
-        if usage_metadata:
-            token_usage = {"total_tokens": usage_metadata["total_tokens"]}
-            completion_tokens = usage_metadata["output_tokens"]
-            prompt_tokens = usage_metadata["input_tokens"]
-            if response.llm_output is None:
-                # model name (and therefore cost) is unavailable in
-                # streaming responses
-                model_name = ""
-            else:
-                model_name = response.llm_output.get("model_name", "")
 
+        model_id = ""
+        usage_metadata = None
+        if hasattr(response, "llm_output") and response.llm_output is not None:
+            model_id = response.llm_output.get("model_id", model_id)
+            model_id = response.llm_output.get("model_name", model_id)
+            usage_metadata = response.llm_output.get("usage", usage_metadata)
+            usage_metadata = response.llm_output.get("token_usage", usage_metadata)
+        elif isinstance(generation, ChatGeneration):
+            if hasattr(generation, "response_metadata"):
+                model_id = generation.response_metadata.get("model_id", model_id)
+                model_id = generation.response_metadata.get("model_name", model_id)
+                usage_metadata = generation.response_metadata.get("usage", usage_metadata)
+            elif hasattr(generation, "message"):
+                if isinstance(generation.message, AIMessage):
+                    usage_metadata = generation.message.usage_metadata
+
+        completion_tokens = 0
+        prompt_tokens = 0
+        total_tokens = 0
+        if usage_metadata:
+            prompt_tokens = usage_metadata.get("prompt_tokens", prompt_tokens)
+            prompt_tokens = usage_metadata.get("input_tokens", prompt_tokens)
+            completion_tokens = usage_metadata.get("completion_tokens", completion_tokens)
+            completion_tokens = usage_metadata.get("output_tokens", completion_tokens)
+            total_tokens = usage_metadata.get("total_tokens", total_tokens)
         else:
-
-
-
-            if "token_usage" not in response.llm_output:
-                with self._lock:
-                    self.successful_requests += 1
-                return None
-
-            # compute tokens and cost for this request
-            token_usage = response.llm_output["token_usage"]
-            completion_tokens = token_usage.get("completion_tokens", 0)
-            prompt_tokens = token_usage.get("prompt_tokens", 0)
-            model_name = response.llm_output.get("model_name", "")
+            with self._lock:
+                self.successful_requests += 1
+            return None
 
         total_cost = _get_token_cost(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
-            model_id=
+            model_id=model_id,
         )
 
         # update shared state behind lock
         with self._lock:
             self.total_cost += total_cost
-            self.total_tokens +=
+            self.total_tokens += total_tokens
             self.prompt_tokens += prompt_tokens
             self.completion_tokens += completion_tokens
             self.successful_requests += 1
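The per-request cost implied by `COST_PER_1K_TOKENS` is rate times tokens divided by 1,000 for each direction. The exact body of `_get_token_cost` isn't shown in this diff, so the helper below is only a sketch of that arithmetic:

```python
COST_PER_1K_TOKENS = {
    "mistral.mistral-large-2402-v1:0": {"input": 0.004, "output": 0.012},
}

def token_cost(prompt_tokens: int, completion_tokens: int, model_id: str) -> float:
    """Sketch of the cost math; not the library's _get_token_cost."""
    rates = COST_PER_1K_TOKENS[model_id]
    return (prompt_tokens * rates["input"] + completion_tokens * rates["output"]) / 1000

# 2,000 prompt + 500 completion tokens on mistral-large:
# 2000*0.004/1000 + 500*0.012/1000 = 0.008 + 0.006 = 0.014 USD
print(token_cost(2000, 500, "mistral.mistral-large-2402-v1:0"))
```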
janus/llm/models_info.py
CHANGED
@@ -8,18 +8,18 @@ from langchain_community.llms import HuggingFaceTextGenInference
 from langchain_core.language_models import BaseLanguageModel
 from langchain_openai import ChatOpenAI
 
-from
-from janus.prompts.prompt import (
+from ..prompts.prompt import (
     ChatGptPromptEngine,
     ClaudePromptEngine,
     CoherePromptEngine,
     Llama2PromptEngine,
     Llama3PromptEngine,
+    MistralPromptEngine,
     PromptEngine,
     TitanPromptEngine,
 )
-
 from ..utils.logger import create_logger
+from .model_callbacks import COST_PER_1K_TOKENS
 
 log = create_logger(__name__)
 
@@ -86,12 +86,18 @@ titan_models = [
 cohere_models = [
     "bedrock-command-r-plus",
 ]
+mistral_models = [
+    "bedrock-mistral-7b-instruct",
+    "bedrock-mistral-large",
+    "bedrock-mixtral",
+]
 bedrock_models = [
     *claude_models,
     *llama2_models,
     *llama3_models,
     *titan_models,
     *cohere_models,
+    *mistral_models,
 ]
 all_models = [*openai_models, *bedrock_models]
 
@@ -119,6 +125,7 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
     **{m: Llama3PromptEngine for m in llama3_models},
     **{m: TitanPromptEngine for m in titan_models},
     **{m: CoherePromptEngine for m in cohere_models},
+    **{m: MistralPromptEngine for m in mistral_models},
 }
 
 _open_ai_defaults: dict[str, str] = {
@@ -143,6 +150,9 @@ model_identifiers = {
     "bedrock-jurassic-2-mid": "ai21.j2-mid-v1",
     "bedrock-jurassic-2-ultra": "ai21.j2-ultra-v1",
     "bedrock-command-r-plus": "cohere.command-r-plus-v1:0",
+    "bedrock-mixtral": "mistral.mixtral-8x7b-instruct-v0:1",
+    "bedrock-mistral-7b-instruct": "mistral.mistral-7b-instruct-v0:2",
+    "bedrock-mistral-large": "mistral.mistral-large-2402-v1:0",
 }
 
 MODEL_DEFAULT_ARGUMENTS: dict[str, dict[str, str]] = {
@@ -183,6 +193,9 @@ TOKEN_LIMITS: dict[str, int] = {
     "ai21.j2-mid-v1": 8192,
     "ai21.j2-ultra-v1": 8192,
     "cohere.command-r-plus-v1:0": 128_000,
+    "mistral.mixtral-8x7b-instruct-v0:1": 32_000,
+    "mistral.mistral-7b-instruct-v0:2": 32_000,
+    "mistral.mistral-large-2402-v1:0": 32_000,
 }
 
 
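How the new Mistral entries connect: user-facing aliases in `model_identifiers` map to Bedrock model ids, and those ids key both `TOKEN_LIMITS` here and `COST_PER_1K_TOKENS` in model_callbacks.py. A sketch with the tables abbreviated to one entry each:

```python
model_identifiers = {"bedrock-mixtral": "mistral.mixtral-8x7b-instruct-v0:1"}
TOKEN_LIMITS = {"mistral.mixtral-8x7b-instruct-v0:1": 32_000}

model_id = model_identifiers["bedrock-mixtral"]  # alias -> Bedrock model id
print(TOKEN_LIMITS[model_id])                    # 32000-token context window
```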
janus/metrics/_tests/test_llm.py
CHANGED
@@ -3,8 +3,7 @@ from unittest.mock import patch
 
 import pytest
 
-from
-
+from ...llm.models_info import load_model
 from ..llm_metrics import llm_evaluate_option, llm_evaluate_ref_option
 
 
@@ -40,7 +39,7 @@ class TestLLMMetrics(unittest.TestCase):
        print("'Hello, world!")
        """
 
-    @patch("
+    @patch(".llm.models_info.load_model")
     @patch("janus.metrics.llm_metrics.llm_evaluate")
     @pytest.mark.llm_eval
     def test_llm_self_eval_quality(self, mock_llm_evaluate, mock_load_model):
janus/metrics/complexity_metrics.py
CHANGED
@@ -1,10 +1,9 @@
 import math
 from typing import List, Optional
 
-from
-from
-from
-
+from ..language.block import CodeBlock
+from ..language.treesitter.treesitter import TreeSitterSplitter
+from ..utils.enums import LANGUAGES
 from .metric import metric
 
 
janus/metrics/metric.py
CHANGED
@@ -7,10 +7,9 @@ import click
 import typer
 from typing_extensions import Annotated
 
-from
-from
-from
-
+from ..llm import load_model
+from ..utils.enums import LANGUAGES
+from ..utils.logger import create_logger
 from ..utils.progress import track
 from .cli import evaluate
 from .file_pairing import FILE_PAIRING_METHODS
janus/metrics/reading.py
CHANGED
@@ -1,9 +1,30 @@
+import re
+
 import nltk
 import readability
+from nltk.tokenize import TweetTokenizer
 
 from .metric import metric
 
 
+def word_count(text):
+    """Calculates word count exactly how readability package does
+
+    Arguments:
+        text: The input string.
+
+    Returns:
+        Word Count
+    """
+    tokenizer = TweetTokenizer()
+    word_count = 0
+    tokens = tokenizer.tokenize(text)
+    for t in tokens:
+        if not re.match(r"^[.,\/#!$%'\^&\*;:{}=\-_`~()]+$", t):
+            word_count += 1
+    return word_count
+
+
 def _repeat_text(text):
     """Repeats a string until its length is over 100 words.
 
@@ -20,11 +41,10 @@ def _repeat_text(text):
     if not text.endswith("."):
         text += "."  # Add a period if missing
 
-    # Check if repeated text is long enough, repeat more if needed
     repeated_text = text
-    while len(repeated_text.split()) < 100:
-        repeated_text += " " + text
 
+    while word_count(repeated_text) < 100:
+        repeated_text += " " + text
     return repeated_text
 
 
@@ -52,7 +72,8 @@ def flesch(target: str, **kwargs) -> float:
     Returns:
         The Flesch score.
     """
-
+    if not target.strip():  # Check if the target text is blank
+        return None
     return get_readability(target).flesch().score
 
 
@@ -66,5 +87,6 @@ def gunning_fog(target: str, **kwargs) -> float:
     Returns:
         The Gunning-Fog score.
     """
-
+    if not target.strip():  # Check if the target text is blank
+        return None
     return get_readability(target).gunning_fog().score
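Why `_repeat_text` switched from `len(text.split())` to `word_count`: `TweetTokenizer` splits punctuation into separate tokens, which the new helper then filters out, matching how the readability package counts words. A quick comparison (requires nltk):

```python
from nltk.tokenize import TweetTokenizer

text = "Hello, world! This is short."
print(len(text.split()))                     # 5 whitespace-delimited tokens
print(len(TweetTokenizer().tokenize(text)))  # 8 tokens, incl. ",", "!", "."
# word_count() above drops the 3 punctuation-only tokens, yielding 5.
```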
janus/prompts/prompt.py
CHANGED
@@ -2,12 +2,12 @@ import json
 from abc import ABC, abstractmethod
 from pathlib import Path
 
-from langchain import PromptTemplate
 from langchain.prompts import ChatPromptTemplate
 from langchain.prompts.chat import (
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
+from langchain_core.prompts import PromptTemplate
 
 from ..utils.enums import LANGUAGES
 from ..utils.logger import create_logger
@@ -34,15 +34,59 @@ HUMAN_PROMPT_TEMPLATE_FILENAME = "human.txt"
 PROMPT_VARIABLES_FILENAME = "variables.json"
 
 
+retry_with_output_prompt_text = """Instructions:
+--------------
+{instructions}
+--------------
+Completion:
+--------------
+{completion}
+--------------
+
+Above, the Completion did not satisfy the constraints given in the Instructions.
+Error:
+--------------
+{error}
+--------------
+
+Please try again. Please only respond with an answer that satisfies the
+constraints laid out in the Instructions:"""
+
+
+retry_with_error_and_output_prompt_text = """Prompt:
+--------------
+{prompt}
+--------------
+Completion:
+--------------
+{completion}
+--------------
+
+Above, the Completion did not satisfy the constraints given in the Prompt.
+Error:
+--------------
+{error}
+--------------
+
+Please try again. Please only respond with an answer that satisfies the
+constraints laid out in the Prompt:"""
+
+
+retry_with_output_prompt = PromptTemplate.from_template(retry_with_output_prompt_text)
+retry_with_error_and_output_prompt = PromptTemplate.from_template(
+    retry_with_error_and_output_prompt_text
+)
+
+
 class PromptEngine(ABC):
     """A class defining prompting schemes for the LLM."""
 
     def __init__(
         self,
         source_language: str,
-        target_language: str,
-        target_version: str,
         prompt_template: str,
+        target_language: str | None = None,
+        target_version: str | None = None,
     ) -> None:
         """Initialize a PromptEngine instance.
 
@@ -63,15 +107,18 @@ class PromptEngine(ABC):
 
         # Define variables to be passed in to the prompt formatter
         source_language = source_language.lower()
-        target_language = target_language.lower()
         self.variables = dict(
             SOURCE_LANGUAGE=source_language,
-            TARGET_LANGUAGE=target_language,
-            TARGET_LANGUAGE_VERSION=str(target_version),
             FILE_SUFFIX=LANGUAGES[source_language]["suffix"],
             SOURCE_CODE_EXAMPLE=LANGUAGES[source_language]["example"],
-            TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
         )
+        if target_language is not None:
+            target_language = target_language.lower()
+            self.variables.update(
+                TARGET_LANGUAGE=target_language,
+                TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
+            )
+        self.variables.update(TARGET_LANGUAGE_VERSION=str(target_version))
         variables_path = template_path / PROMPT_VARIABLES_FILENAME
         if variables_path.exists():
             self.variables.update(json.loads(variables_path.read_text()))
@@ -219,3 +266,16 @@ class CoherePromptEngine(PromptEngine):
             f"{human_prompt}"
             f"<|END_OF_TURN_TOKEN|>"
         )
+
+
+class MistralPromptEngine(PromptEngine):
+    def load_prompt_template(self, template_path: Path) -> ChatPromptTemplate:
+        system_prompt_path = template_path / SYSTEM_PROMPT_TEMPLATE_FILENAME
+        system_prompt = system_prompt_path.read_text()
+
+        human_prompt_path = template_path / HUMAN_PROMPT_TEMPLATE_FILENAME
+        human_prompt = human_prompt_path.read_text()
+
+        return PromptTemplate.from_template(
+            f"<s>[INST] {system_prompt} [/INST] </s>[INST] {human_prompt} [/INST]"
+        )
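What `MistralPromptEngine` produces, shown standalone with langchain-core; the system/human template texts here are made-up stand-ins for the files the engine reads from disk:

```python
from langchain_core.prompts import PromptTemplate

system_prompt = "You are a code translation assistant."  # stand-in text
human_prompt = "Translate this {SOURCE_LANGUAGE} code."  # stand-in text
template = PromptTemplate.from_template(
    f"<s>[INST] {system_prompt} [/INST] </s>[INST] {human_prompt} [/INST]"
)
# Braces in the human prompt become template variables:
print(template.format(SOURCE_LANGUAGE="fortran"))
```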
janus/utils/enums.py
CHANGED
@@ -10,7 +10,7 @@ class EmbeddingType(Enum):
     TARGET = 5  # placeholder embeddings, are these useful for analysis?
 
 
-CUSTOM_SPLITTERS: Set[str] = {"mumps", "binary"}
+CUSTOM_SPLITTERS: Set[str] = {"mumps", "binary", "ibmhlasm"}
 
 LANGUAGES: Dict[str, Dict[str, Any]] = {
     "ada": {
@@ -63,7 +63,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
             '#include <stdio.h>\n\nint main() {\n    printf("Hello, World!\\n");\n'
             "    return 0;\n}\n"
         ),
-        "
+        "functional_node_types": ["function_definition"],
         "comment_node_type": "comment",
     },
     "capnp": {
@@ -206,7 +206,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
         "example": (
             "program HelloWorld\n    print *, 'Hello, World!'\nend program HelloWorld\n"
         ),
-        "
+        "functional_node_types": ["function"],
        "comment_node_type": "comment",
     },
     "gitattributes": {
@@ -300,6 +300,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
 END HELLO
 """
         ),
+        "functional_node_types": ["csect", "dsect"],
         "branch_node_types": ["branch_instruction"],
         "operation_node_types": ["operation", "branch_operation"],
         "operand_node_types": ["operands"],
@@ -420,7 +421,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
         "suffix": "m",
         "url": "https://github.com/janus-llm/tree-sitter-mumps",
         "example": 'WRITE "Hello, World!"',
-        "
+        "functional_node_types": ["routine_definition"],
         "comment_node_type": "comment",
         "branch_node_types": ["if_statement"],
         "operation_node_types": [
@@ -512,7 +513,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
         "suffix": "py",
         "url": "https://github.com/tree-sitter/tree-sitter-python",
         "example": "# Hello, World!\nprint('Hello, World!')\n",
-        "
+        "functional_node_types": ["function_definition"],
         "comment_node_type": "comment",
     },
     "qmljs": {