janus-llm 2.0.2__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +2 -2
- janus/__main__.py +1 -1
- janus/_tests/test_cli.py +1 -2
- janus/cli.py +43 -51
- janus/converter/__init__.py +6 -0
- janus/converter/_tests/__init__.py +0 -0
- janus/{_tests → converter/_tests}/test_translate.py +11 -22
- janus/converter/converter.py +614 -0
- janus/converter/diagram.py +124 -0
- janus/converter/document.py +131 -0
- janus/converter/evaluate.py +15 -0
- janus/converter/requirements.py +50 -0
- janus/converter/translate.py +108 -0
- janus/embedding/_tests/test_collections.py +2 -2
- janus/language/_tests/test_splitter.py +1 -1
- janus/language/alc/__init__.py +1 -0
- janus/language/alc/_tests/__init__.py +0 -0
- janus/language/alc/_tests/test_alc.py +28 -0
- janus/language/alc/alc.py +87 -0
- janus/language/block.py +4 -2
- janus/language/combine.py +0 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/naive/__init__.py +1 -1
- janus/language/naive/basic_splitter.py +4 -4
- janus/language/naive/chunk_splitter.py +4 -4
- janus/language/naive/registry.py +1 -1
- janus/language/naive/simple_ast.py +23 -12
- janus/language/naive/tag_splitter.py +4 -4
- janus/language/splitter.py +10 -4
- janus/language/treesitter/treesitter.py +26 -8
- janus/llm/model_callbacks.py +34 -37
- janus/llm/models_info.py +16 -3
- janus/metrics/_tests/test_llm.py +2 -3
- janus/metrics/_tests/test_rouge_score.py +1 -1
- janus/metrics/_tests/test_similarity_score.py +1 -1
- janus/metrics/complexity_metrics.py +3 -4
- janus/metrics/metric.py +3 -4
- janus/metrics/reading.py +27 -5
- janus/prompts/prompt.py +67 -7
- janus/utils/enums.py +6 -5
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/METADATA +1 -1
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/RECORD +45 -35
- janus/converter.py +0 -158
- janus/translate.py +0 -981
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/entry_points.txt +0 -0
janus/language/naive/simple_ast.py
CHANGED
@@ -1,18 +1,29 @@
-from
-from
-from
+from ...utils.enums import LANGUAGES
+from ..alc.alc import AlcSplitter
+from ..mumps.mumps import MumpsSplitter
+from ..treesitter import TreeSitterSplitter
+from .registry import register_splitter


 @register_splitter("ast-flex")
-
-
+def get_flexible_ast(language: str, **kwargs):
+    if language == "ibmhlasm":
+        return AlcSplitter(**kwargs)
+    elif language == "mumps":
+        return MumpsSplitter(**kwargs)
+    else:
+        return TreeSitterSplitter(language=language, **kwargs)


 @register_splitter("ast-strict")
-
-
-
-
-
-
-
+def get_strict_ast(language: str, **kwargs):
+    kwargs.update(
+        protected_node_types=LANGUAGES[language]["functional_node_types"],
+        prune_unprotected=True,
+    )
+    if language == "ibmhlasm":
+        return AlcSplitter(**kwargs)
+    elif language == "mumps":
+        return MumpsSplitter(**kwargs)
+    else:
+        return TreeSitterSplitter(language=language, **kwargs)
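In short, the registered splitter factories now dispatch to the custom AlcSplitter and MumpsSplitter where tree-sitter grammars fall short, and the "ast-strict" variant additionally protects each language's functional node types while pruning everything else. A minimal usage sketch, grounded in the hunk above (the max_tokens value is illustrative):

from janus.language.naive.simple_ast import get_strict_ast

# For Python this yields a TreeSitterSplitter that protects
# "function_definition" nodes and prunes unprotected ones.
splitter = get_strict_ast("python", max_tokens=4096)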
janus/language/naive/tag_splitter.py
CHANGED
@@ -1,7 +1,7 @@
-from
-from
-from
-from
+from ..block import CodeBlock
+from ..node import NodeType
+from ..splitter import Splitter
+from .registry import register_splitter


 @register_splitter("tag")
janus/language/splitter.py
CHANGED
@@ -47,8 +47,8 @@ class Splitter(FileManager):
         model: None | BaseLanguageModel = None,
         max_tokens: int = 4096,
         skip_merge: bool = False,
-        protected_node_types: tuple[str] = (),
-        prune_node_types: tuple[str] = (),
+        protected_node_types: tuple[str, ...] = (),
+        prune_node_types: tuple[str, ...] = (),
         prune_unprotected: bool = False,
     ):
         """
@@ -340,7 +340,10 @@
         # Double check length (in theory this should never be an issue)
         tokens = self._count_tokens(text)
         if tokens > self.max_tokens:
-            log.error(
+            log.error(
+                f"Merged node ({name}) too long for context!"
+                f" ({tokens} > {self.max_tokens})"
+            )

         return CodeBlock(
             text=text,
@@ -420,7 +423,10 @@
             name = f"{node.name}-L#{node_line}"
             tokens = self._count_tokens(line)
             if tokens > self.max_tokens:
-                raise TokenLimitError(
+                raise TokenLimitError(
+                    "Irreducible node too large for context!"
+                    f" ({tokens} > {self.max_tokens})"
+                )

             node.children.append(
                 CodeBlock(
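The annotation change in the signature hunks above (here and in treesitter.py below) is a genuine fix: tuple[str] types a 1-tuple containing exactly one string, while tuple[str, ...] types a variable-length tuple of strings, which is what these parameters actually receive. A two-line illustration (standard Python typing semantics, not project code):

protected: tuple[str, ...] = ("function_definition", "comment")  # accepted
wrong: tuple[str] = ("function_definition", "comment")  # a type checker rejects this 2-tuple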
janus/language/treesitter/treesitter.py
CHANGED
@@ -1,6 +1,7 @@
 import os
 import platform
 from collections import defaultdict
+from ctypes import c_void_p, cdll
 from pathlib import Path
 from typing import Optional

@@ -26,8 +27,8 @@ class TreeSitterSplitter(Splitter):
         language: str,
         model: None | BaseLanguageModel = None,
         max_tokens: int = 4096,
-        protected_node_types: tuple[str] = (),
-        prune_node_types: tuple[str] = (),
+        protected_node_types: tuple[str, ...] = (),
+        prune_node_types: tuple[str, ...] = (),
         prune_unprotected: bool = False,
     ) -> None:
         """Initialize a TreeSitterSplitter instance.
@@ -48,10 +49,10 @@ class TreeSitterSplitter(Splitter):
         self._load_parser()

     def _get_ast(self, code: str) -> CodeBlock:
-
-        tree = self.parser.parse(
+        code_bytes = bytes(code, "utf-8")
+        tree = self.parser.parse(code_bytes)
         root = tree.walk().node
-        root = self._node_to_block(root,
+        root = self._node_to_block(root, code_bytes)
         return root

     # Recursively print tree to view parsed output (dev helper function)
@@ -98,7 +99,7 @@ class TreeSitterSplitter(Splitter):

         text = node.text.decode()
         children = [self._node_to_block(child, original_text) for child in node.children]
-
+        return CodeBlock(
             id=node.id,
             name=str(node.id),
             text=text,
@@ -112,7 +113,6 @@ class TreeSitterSplitter(Splitter):
             language=self.language,
             tokens=self._count_tokens(text),
         )
-        return node

     def _load_parser(self) -> None:
         """Load the parser for the given language.
@@ -139,7 +139,25 @@ class TreeSitterSplitter(Splitter):

         # Load the parser using the generated .so file
         self.parser: tree_sitter.Parser = tree_sitter.Parser()
-        self.
+        pointer = self._so_to_pointer(so_file)
+        self.parser.set_language(tree_sitter.Language(pointer, self.language))
+
+    def _so_to_pointer(self, so_file: str) -> int:
+        """Convert the .so file to a pointer.
+
+        Taken from `treesitter.Language.__init__` to get past deprecated warning.
+
+        Arguments:
+            so_file: The path to the so file for the language.
+
+        Returns:
+            The pointer to the language.
+        """
+        lib = cdll.LoadLibrary(os.fspath(so_file))
+        language_function = getattr(lib, f"tree_sitter_{self.language}")
+        language_function.restype = c_void_p
+        pointer = language_function()
+        return pointer

     def _create_parser(self, so_file: Path | str) -> None:
         """Create the parser for the given language.
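The new _so_to_pointer helper inlines the ctypes lookup that tree_sitter.Language.__init__ performs internally, sidestepping its deprecation warning. A standalone sketch of the same trick; the shared-object path below is hypothetical, assuming a grammar compiled for Python that exports tree_sitter_python:

from ctypes import c_void_p, cdll

lib = cdll.LoadLibrary("build/python.so")  # hypothetical path to the compiled grammar
lib.tree_sitter_python.restype = c_void_p  # the exported symbol returns a TSLanguage*
language_pointer = lib.tree_sitter_python()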
janus/llm/model_callbacks.py
CHANGED
@@ -8,7 +8,7 @@ from langchain_core.messages import AIMessage
 from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_core.tracers.context import register_configure_hook

-from
+from ..utils.logger import create_logger

 log = create_logger(__name__)

@@ -35,6 +35,9 @@ COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
     "ai21.j2-mid-v1": {"input": 0.0125, "output": 0.0125},
     "ai21.j2-ultra-v1": {"input": 0.0188, "output": 0.0188},
     "cohere.command-r-plus-v1:0": {"input": 0.003, "output": 0.015},
+    "mistral.mistral-7b-instruct-v0:2": {"input": 0.00015, "output": 0.0002},
+    "mistral.mixtral-8x7b-instruct-v0:1": {"input": 0.00045, "output": 0.0007},
+    "mistral.mistral-large-2402-v1:0": {"input": 0.004, "output": 0.012},
 }


@@ -103,53 +106,47 @@ class TokenUsageCallbackHandler(BaseCallbackHandler):
             generation = response.generations[0][0]
         except IndexError:
             generation = None
-        if isinstance(generation, ChatGeneration):
-            try:
-                message = generation.message
-                if isinstance(message, AIMessage):
-                    usage_metadata = message.usage_metadata
-                else:
-                    usage_metadata = None
-            except AttributeError:
-                usage_metadata = None
-        else:
-            usage_metadata = None
-        if usage_metadata:
-            token_usage = {"total_tokens": usage_metadata["total_tokens"]}
-            completion_tokens = usage_metadata["output_tokens"]
-            prompt_tokens = usage_metadata["input_tokens"]
-            if response.llm_output is None:
-                # model name (and therefore cost) is unavailable in
-                # streaming responses
-                model_name = ""
-            else:
-                model_name = response.llm_output.get("model_name", "")

+        model_id = ""
+        usage_metadata = None
+        if hasattr(response, "llm_output") and response.llm_output is not None:
+            model_id = response.llm_output.get("model_id", model_id)
+            model_id = response.llm_output.get("model_name", model_id)
+            usage_metadata = response.llm_output.get("usage", usage_metadata)
+            usage_metadata = response.llm_output.get("token_usage", usage_metadata)
+        elif isinstance(generation, ChatGeneration):
+            if hasattr(generation, "response_metadata"):
+                model_id = generation.response_metadata.get("model_id", model_id)
+                model_id = generation.response_metadata.get("model_name", model_id)
+                usage_metadata = generation.response_metadata.get("usage", usage_metadata)
+            elif hasattr(generation, "message"):
+                if isinstance(generation.message, AIMessage):
+                    usage_metadata = generation.message.usage_metadata
+
+        completion_tokens = 0
+        prompt_tokens = 0
+        total_tokens = 0
+        if usage_metadata:
+            prompt_tokens = usage_metadata.get("prompt_tokens", prompt_tokens)
+            prompt_tokens = usage_metadata.get("input_tokens", prompt_tokens)
+            completion_tokens = usage_metadata.get("completion_tokens", completion_tokens)
+            completion_tokens = usage_metadata.get("output_tokens", completion_tokens)
+            total_tokens = usage_metadata.get("total_tokens", total_tokens)
         else:
-
-
-
-            if "token_usage" not in response.llm_output:
-                with self._lock:
-                    self.successful_requests += 1
-                return None
-
-            # compute tokens and cost for this request
-            token_usage = response.llm_output["token_usage"]
-            completion_tokens = token_usage.get("completion_tokens", 0)
-            prompt_tokens = token_usage.get("prompt_tokens", 0)
-            model_name = response.llm_output.get("model_name", "")
+            with self._lock:
+                self.successful_requests += 1
+            return None

         total_cost = _get_token_cost(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
-            model_id=
+            model_id=model_id,
         )

         # update shared state behind lock
         with self._lock:
             self.total_cost += total_cost
-            self.total_tokens +=
+            self.total_tokens += total_tokens
             self.prompt_tokens += prompt_tokens
             self.completion_tokens += completion_tokens
             self.successful_requests += 1
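The rewritten handler normalizes metadata across OpenAI-style payloads (model_name, token_usage) and Bedrock-style payloads (model_id, usage) before pricing the request. A sketch of the arithmetic _get_token_cost presumably applies with the per-1K table above (the helper's internals are not shown in this diff):

RATES = {"mistral.mixtral-8x7b-instruct-v0:1": {"input": 0.00045, "output": 0.0007}}

def estimate_cost(model_id: str, prompt_tokens: int, completion_tokens: int) -> float:
    # Prices are per 1,000 tokens, split by input (prompt) and output (completion).
    rates = RATES[model_id]
    return (prompt_tokens / 1000) * rates["input"] + (completion_tokens / 1000) * rates["output"]

# 10,000 prompt tokens + 2,000 completion tokens on Mixtral:
# 10 * 0.00045 + 2 * 0.0007 = 0.0059 dollars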
janus/llm/models_info.py
CHANGED
@@ -8,18 +8,18 @@ from langchain_community.llms import HuggingFaceTextGenInference
 from langchain_core.language_models import BaseLanguageModel
 from langchain_openai import ChatOpenAI

-from
-from janus.prompts.prompt import (
+from ..prompts.prompt import (
     ChatGptPromptEngine,
     ClaudePromptEngine,
     CoherePromptEngine,
     Llama2PromptEngine,
     Llama3PromptEngine,
+    MistralPromptEngine,
     PromptEngine,
     TitanPromptEngine,
 )
-
 from ..utils.logger import create_logger
+from .model_callbacks import COST_PER_1K_TOKENS

 log = create_logger(__name__)

@@ -86,12 +86,18 @@ titan_models = [
 cohere_models = [
     "bedrock-command-r-plus",
 ]
+mistral_models = [
+    "bedrock-mistral-7b-instruct",
+    "bedrock-mistral-large",
+    "bedrock-mixtral",
+]
 bedrock_models = [
     *claude_models,
     *llama2_models,
     *llama3_models,
     *titan_models,
     *cohere_models,
+    *mistral_models,
 ]
 all_models = [*openai_models, *bedrock_models]

@@ -119,6 +125,7 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
     **{m: Llama3PromptEngine for m in llama3_models},
     **{m: TitanPromptEngine for m in titan_models},
     **{m: CoherePromptEngine for m in cohere_models},
+    **{m: MistralPromptEngine for m in mistral_models},
 }

 _open_ai_defaults: dict[str, str] = {
@@ -143,6 +150,9 @@ model_identifiers = {
     "bedrock-jurassic-2-mid": "ai21.j2-mid-v1",
     "bedrock-jurassic-2-ultra": "ai21.j2-ultra-v1",
     "bedrock-command-r-plus": "cohere.command-r-plus-v1:0",
+    "bedrock-mixtral": "mistral.mixtral-8x7b-instruct-v0:1",
+    "bedrock-mistral-7b-instruct": "mistral.mistral-7b-instruct-v0:2",
+    "bedrock-mistral-large": "mistral.mistral-large-2402-v1:0",
 }

 MODEL_DEFAULT_ARGUMENTS: dict[str, dict[str, str]] = {
@@ -183,6 +193,9 @@ TOKEN_LIMITS: dict[str, int] = {
     "ai21.j2-mid-v1": 8192,
     "ai21.j2-ultra-v1": 8192,
     "cohere.command-r-plus-v1:0": 128_000,
+    "mistral.mixtral-8x7b-instruct-v0:1": 32_000,
+    "mistral.mistral-7b-instruct-v0:2": 32_000,
+    "mistral.mistral-large-2402-v1:0": 32_000,
 }

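Each new Bedrock alias is wired through the same three lookups the existing models use: alias to Bedrock identifier, identifier to pricing and token limit, and alias to prompt engine. Tracing one alias through the values added above:

model_id = model_identifiers["bedrock-mixtral"]   # "mistral.mixtral-8x7b-instruct-v0:1"
limit = TOKEN_LIMITS[model_id]                    # 32_000
engine = MODEL_PROMPT_ENGINES["bedrock-mixtral"]  # MistralPromptEngine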
janus/metrics/_tests/test_llm.py
CHANGED
@@ -3,8 +3,7 @@ from unittest.mock import patch

 import pytest

-from
-
+from ...llm.models_info import load_model
 from ..llm_metrics import llm_evaluate_option, llm_evaluate_ref_option


@@ -40,7 +39,7 @@ class TestLLMMetrics(unittest.TestCase):
        print("'Hello, world!")
        """

-    @patch("
+    @patch(".llm.models_info.load_model")
     @patch("janus.metrics.llm_metrics.llm_evaluate")
     @pytest.mark.llm_eval
     def test_llm_self_eval_quality(self, mock_llm_evaluate, mock_load_model):
janus/metrics/complexity_metrics.py
CHANGED
@@ -1,10 +1,9 @@
 import math
 from typing import List, Optional

-from
-from
-from
-
+from ..language.block import CodeBlock
+from ..language.treesitter.treesitter import TreeSitterSplitter
+from ..utils.enums import LANGUAGES
 from .metric import metric

janus/metrics/metric.py
CHANGED
@@ -7,10 +7,9 @@ import click
 import typer
 from typing_extensions import Annotated

-from
-from
-from
-
+from ..llm import load_model
+from ..utils.enums import LANGUAGES
+from ..utils.logger import create_logger
 from ..utils.progress import track
 from .cli import evaluate
 from .file_pairing import FILE_PAIRING_METHODS
janus/metrics/reading.py
CHANGED
@@ -1,9 +1,30 @@
+import re
+
 import nltk
 import readability
+from nltk.tokenize import TweetTokenizer

 from .metric import metric


+def word_count(text):
+    """Calculates word count exactly how readability package does
+
+    Arguments:
+        text: The input string.
+
+    Returns:
+        Word Count
+    """
+    tokenizer = TweetTokenizer()
+    word_count = 0
+    tokens = tokenizer.tokenize(text)
+    for t in tokens:
+        if not re.match(r"^[.,\/#!$%'\^&\*;:{}=\-_`~()]+$", t):
+            word_count += 1
+    return word_count
+
+
 def _repeat_text(text):
     """Repeats a string until its length is over 100 words.

@@ -20,11 +41,10 @@ def _repeat_text(text):
     if not text.endswith("."):
         text += "."  # Add a period if missing

-    # Check if repeated text is long enough, repeat more if needed
     repeated_text = text
-    while len(repeated_text.split()) < 100:
-        repeated_text += " " + text

+    while word_count(repeated_text) < 100:
+        repeated_text += " " + text
     return repeated_text


@@ -52,7 +72,8 @@ def flesch(target: str, **kwargs) -> float:
     Returns:
         The Flesch score.
     """
-
+    if not target.strip():  # Check if the target text is blank
+        return None
     return get_readability(target).flesch().score


@@ -66,5 +87,6 @@ def gunning_fog(target: str, **kwargs) -> float:
     Returns:
         The Gunning-Fog score.
     """
-
+    if not target.strip():  # Check if the target text is blank
+        return None
     return get_readability(target).gunning_fog().score
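The new word_count helper matters because len(text.split()) and the readability package disagree on what a word is: split() counts every whitespace-separated chunk, while the TweetTokenizer-based count drops punctuation-only tokens. A quick sketch of the difference (assumes nltk is installed):

import re
from nltk.tokenize import TweetTokenizer

text = "Hello, world - again."
len(text.split())  # 4: the lone "-" counts as a word

tokens = TweetTokenizer().tokenize(text)  # ['Hello', ',', 'world', '-', 'again', '.']
words = [t for t in tokens if not re.match(r"^[.,\/#!$%'\^&\*;:{}=\-_`~()]+$", t)]
len(words)  # 3: punctuation-only tokens are excluded, as in word_count above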
janus/prompts/prompt.py
CHANGED
@@ -2,12 +2,12 @@ import json
 from abc import ABC, abstractmethod
 from pathlib import Path

-from langchain import PromptTemplate
 from langchain.prompts import ChatPromptTemplate
 from langchain.prompts.chat import (
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
+from langchain_core.prompts import PromptTemplate

 from ..utils.enums import LANGUAGES
 from ..utils.logger import create_logger
@@ -34,15 +34,59 @@ HUMAN_PROMPT_TEMPLATE_FILENAME = "human.txt"
 PROMPT_VARIABLES_FILENAME = "variables.json"


+retry_with_output_prompt_text = """Instructions:
+--------------
+{instructions}
+--------------
+Completion:
+--------------
+{completion}
+--------------
+
+Above, the Completion did not satisfy the constraints given in the Instructions.
+Error:
+--------------
+{error}
+--------------
+
+Please try again. Please only respond with an answer that satisfies the
+constraints laid out in the Instructions:"""
+
+
+retry_with_error_and_output_prompt_text = """Prompt:
+--------------
+{prompt}
+--------------
+Completion:
+--------------
+{completion}
+--------------
+
+Above, the Completion did not satisfy the constraints given in the Prompt.
+Error:
+--------------
+{error}
+--------------
+
+Please try again. Please only respond with an answer that satisfies the
+constraints laid out in the Prompt:"""
+
+
+retry_with_output_prompt = PromptTemplate.from_template(retry_with_output_prompt_text)
+retry_with_error_and_output_prompt = PromptTemplate.from_template(
+    retry_with_error_and_output_prompt_text
+)
+
+
 class PromptEngine(ABC):
     """A class defining prompting schemes for the LLM."""

     def __init__(
         self,
         source_language: str,
-        target_language: str,
-        target_version: str,
         prompt_template: str,
+        target_language: str | None = None,
+        target_version: str | None = None,
     ) -> None:
         """Initialize a PromptEngine instance.

@@ -63,15 +107,18 @@ class PromptEngine(ABC):

         # Define variables to be passed in to the prompt formatter
         source_language = source_language.lower()
-        target_language = target_language.lower()
         self.variables = dict(
             SOURCE_LANGUAGE=source_language,
-            TARGET_LANGUAGE=target_language,
-            TARGET_LANGUAGE_VERSION=str(target_version),
             FILE_SUFFIX=LANGUAGES[source_language]["suffix"],
             SOURCE_CODE_EXAMPLE=LANGUAGES[source_language]["example"],
-            TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
         )
+        if target_language is not None:
+            target_language = target_language.lower()
+            self.variables.update(
+                TARGET_LANGUAGE=target_language,
+                TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
+            )
+        self.variables.update(TARGET_LANGUAGE_VERSION=str(target_version))
         variables_path = template_path / PROMPT_VARIABLES_FILENAME
         if variables_path.exists():
             self.variables.update(json.loads(variables_path.read_text()))
@@ -219,3 +266,16 @@ class CoherePromptEngine(PromptEngine):
             f"{human_prompt}"
             f"<|END_OF_TURN_TOKEN|>"
         )
+
+
+class MistralPromptEngine(PromptEngine):
+    def load_prompt_template(self, template_path: Path) -> ChatPromptTemplate:
+        system_prompt_path = template_path / SYSTEM_PROMPT_TEMPLATE_FILENAME
+        system_prompt = system_prompt_path.read_text()
+
+        human_prompt_path = template_path / HUMAN_PROMPT_TEMPLATE_FILENAME
+        human_prompt = human_prompt_path.read_text()
+
+        return PromptTemplate.from_template(
+            f"<s>[INST] {system_prompt} [/INST] </s>[INST] {human_prompt} [/INST]"
+        )
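The two retry templates become ordinary langchain PromptTemplate objects, so their placeholders ({instructions}, {completion}, {error}, {prompt}) are inferred as input variables. Formatting one of them looks like this (the values are illustrative):

formatted = retry_with_output_prompt.format(
    instructions="Respond with valid JSON only.",
    completion="Sure! Here is the JSON you asked for: ...",
    error="JSONDecodeError: Expecting value: line 1 column 1",
)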
janus/utils/enums.py
CHANGED
@@ -10,7 +10,7 @@ class EmbeddingType(Enum):
     TARGET = 5  # placeholder embeddings, are these useful for analysis?


-CUSTOM_SPLITTERS: Set[str] = {"mumps", "binary"}
+CUSTOM_SPLITTERS: Set[str] = {"mumps", "binary", "ibmhlasm"}

 LANGUAGES: Dict[str, Dict[str, Any]] = {
     "ada": {
@@ -63,7 +63,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
             '#include <stdio.h>\n\nint main() {\n printf("Hello, World!\\n");\n'
             " return 0;\n}\n"
         ),
-        "
+        "functional_node_types": ["function_definition"],
         "comment_node_type": "comment",
     },
     "capnp": {
@@ -206,7 +206,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
         "example": (
             "program HelloWorld\n print *, 'Hello, World!'\nend program HelloWorld\n"
         ),
-        "
+        "functional_node_types": ["function"],
         "comment_node_type": "comment",
     },
     "gitattributes": {
@@ -300,6 +300,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
             END HELLO
             """
         ),
+        "functional_node_types": ["csect", "dsect"],
         "branch_node_types": ["branch_instruction"],
         "operation_node_types": ["operation", "branch_operation"],
         "operand_node_types": ["operands"],
@@ -420,7 +421,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
         "suffix": "m",
         "url": "https://github.com/janus-llm/tree-sitter-mumps",
         "example": 'WRITE "Hello, World!"',
-        "
+        "functional_node_types": ["routine_definition"],
         "comment_node_type": "comment",
         "branch_node_types": ["if_statement"],
         "operation_node_types": [
@@ -512,7 +513,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
         "suffix": "py",
         "url": "https://github.com/tree-sitter/tree-sitter-python",
         "example": "# Hello, World!\nprint('Hello, World!')\n",
-        "
+        "functional_node_types": ["function_definition"],
         "comment_node_type": "comment",
     },
     "qmljs": {