janus-llm 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
- janus/__init__.py +9 -1
- janus/__main__.py +4 -0
- janus/_tests/test_cli.py +128 -0
- janus/_tests/test_translate.py +49 -7
- janus/cli.py +530 -46
- janus/converter.py +50 -19
- janus/embedding/_tests/test_collections.py +2 -8
- janus/embedding/_tests/test_database.py +32 -0
- janus/embedding/_tests/test_vectorize.py +9 -4
- janus/embedding/collections.py +49 -6
- janus/embedding/embedding_models_info.py +120 -0
- janus/embedding/vectorize.py +53 -62
- janus/language/_tests/__init__.py +0 -0
- janus/language/_tests/test_combine.py +62 -0
- janus/language/_tests/test_splitter.py +16 -0
- janus/language/binary/_tests/test_binary.py +16 -1
- janus/language/binary/binary.py +10 -3
- janus/language/block.py +31 -30
- janus/language/combine.py +26 -34
- janus/language/mumps/_tests/test_mumps.py +2 -2
- janus/language/mumps/mumps.py +93 -9
- janus/language/naive/__init__.py +4 -0
- janus/language/naive/basic_splitter.py +14 -0
- janus/language/naive/chunk_splitter.py +26 -0
- janus/language/naive/registry.py +13 -0
- janus/language/naive/simple_ast.py +18 -0
- janus/language/naive/tag_splitter.py +61 -0
- janus/language/splitter.py +168 -74
- janus/language/treesitter/_tests/test_treesitter.py +9 -6
- janus/language/treesitter/treesitter.py +37 -13
- janus/llm/model_callbacks.py +177 -0
- janus/llm/models_info.py +134 -70
- janus/metrics/__init__.py +8 -0
- janus/metrics/_tests/__init__.py +0 -0
- janus/metrics/_tests/reference.py +2 -0
- janus/metrics/_tests/target.py +2 -0
- janus/metrics/_tests/test_bleu.py +56 -0
- janus/metrics/_tests/test_chrf.py +67 -0
- janus/metrics/_tests/test_file_pairing.py +59 -0
- janus/metrics/_tests/test_llm.py +91 -0
- janus/metrics/_tests/test_reading.py +28 -0
- janus/metrics/_tests/test_rouge_score.py +65 -0
- janus/metrics/_tests/test_similarity_score.py +23 -0
- janus/metrics/_tests/test_treesitter_metrics.py +110 -0
- janus/metrics/bleu.py +66 -0
- janus/metrics/chrf.py +55 -0
- janus/metrics/cli.py +7 -0
- janus/metrics/complexity_metrics.py +208 -0
- janus/metrics/file_pairing.py +113 -0
- janus/metrics/llm_metrics.py +202 -0
- janus/metrics/metric.py +466 -0
- janus/metrics/reading.py +70 -0
- janus/metrics/rouge_score.py +96 -0
- janus/metrics/similarity.py +53 -0
- janus/metrics/splitting.py +38 -0
- janus/parsers/_tests/__init__.py +0 -0
- janus/parsers/_tests/test_code_parser.py +32 -0
- janus/parsers/code_parser.py +24 -253
- janus/parsers/doc_parser.py +169 -0
- janus/parsers/eval_parser.py +80 -0
- janus/parsers/reqs_parser.py +72 -0
- janus/prompts/prompt.py +103 -30
- janus/translate.py +636 -111
- janus/utils/_tests/__init__.py +0 -0
- janus/utils/_tests/test_logger.py +67 -0
- janus/utils/_tests/test_progress.py +20 -0
- janus/utils/enums.py +56 -3
- janus/utils/progress.py +56 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/METADATA +23 -10
- janus_llm-2.0.0.dist-info/RECORD +94 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/WHEEL +1 -1
- janus_llm-1.0.0.dist-info/RECORD +0 -48
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/entry_points.txt +0 -0
janus/llm/model_callbacks.py
@@ -0,0 +1,177 @@
import threading
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Generator

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.tracers.context import register_configure_hook

from janus.utils.logger import create_logger

log = create_logger(__name__)


# Updated 2024-06-21
COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
    "gpt-3.5-turbo-0125": {"input": 0.0005, "output": 0.0015},
    "gpt-4-1106-preview": {"input": 0.01, "output": 0.03},
    "gpt-4-0125-preview": {"input": 0.01, "output": 0.03},
    "gpt-4-0613": {"input": 0.03, "output": 0.06},
    "gpt-4o-2024-05-13": {"input": 0.005, "output": 0.015},
    "anthropic.claude-v2": {"input": 0.008, "output": 0.024},
    "anthropic.claude-instant-v1": {"input": 0.0008, "output": 0.0024},
    "anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.00025, "output": 0.00125},
    "anthropic.claude-3-sonnet-20240229-v1:0": {"input": 0.003, "output": 0.015},
    "meta.llama2-13b-chat-v1": {"input": 0.00075, "output": 0.001},
    "meta.llama2-70b-chat-v1": {"input": 0.00195, "output": 0.00256},
    "meta.llama2-13b-v1": {"input": 0.0, "output": 0.0},
    "meta.llama2-70b-v1": {"input": 0.00265, "output": 0.0035},
    "meta.llama3-8b-instruct-v1:0": {"input": 0.0003, "output": 0.0006},
    "meta.llama3-70b-instruct-v1:0": {"input": 0.00265, "output": 0.0035},
    "amazon.titan-text-lite-v1": {"input": 0.00015, "output": 0.0002},
    "amazon.titan-text-express-v1": {"input": 0.0002, "output": 0.0006},
    "ai21.j2-mid-v1": {"input": 0.0125, "output": 0.0125},
    "ai21.j2-ultra-v1": {"input": 0.0188, "output": 0.0188},
    "cohere.command-r-plus-v1:0": {"input": 0.003, "output": 0.015},
}


def _get_token_cost(
    prompt_tokens: int, completion_tokens: int, model_id: str | None
) -> float:
    """Get the cost of tokens according to model ID"""
    if model_id not in COST_PER_1K_TOKENS:
        raise ValueError(
            f"Unknown model: {model_id}. Please provide a valid model name."
            f" Known models are: {', '.join(COST_PER_1K_TOKENS.keys())}"
        )
    model_cost = COST_PER_1K_TOKENS[model_id]
    input_cost = (prompt_tokens / 1000.0) * model_cost["input"]
    output_cost = (completion_tokens / 1000.0) * model_cost["output"]
    return input_cost + output_cost


class TokenUsageCallbackHandler(BaseCallbackHandler):
    """Callback Handler that tracks metadata on model cost, retries, etc.

    Based on https://github.com/langchain-ai/langchain/blob/master/libs
    /community/langchain_community/callbacks/openai_info.py
    """

    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0
    total_cost: float = 0.0

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()

    def __repr__(self) -> str:
        return (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"Successful Requests: {self.successful_requests}\n"
            f"Total Cost (USD): ${self.total_cost}"
        )

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_chat_model_start(self, *args, **kwargs):
        pass

    def on_llm_start(
        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
    ) -> None:
        """Print out the prompts."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Print out the token."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage."""
        # Check for usage_metadata (langchain-core >= 0.2.2)
        try:
            generation = response.generations[0][0]
        except IndexError:
            generation = None
        if isinstance(generation, ChatGeneration):
            try:
                message = generation.message
                if isinstance(message, AIMessage):
                    usage_metadata = message.usage_metadata
                else:
                    usage_metadata = None
            except AttributeError:
                usage_metadata = None
        else:
            usage_metadata = None
        if usage_metadata:
            token_usage = {"total_tokens": usage_metadata["total_tokens"]}
            completion_tokens = usage_metadata["output_tokens"]
            prompt_tokens = usage_metadata["input_tokens"]
            if response.llm_output is None:
                # model name (and therefore cost) is unavailable in
                # streaming responses
                model_name = ""
            else:
                model_name = response.llm_output.get("model_name", "")

        else:
            if response.llm_output is None:
                return None

            if "token_usage" not in response.llm_output:
                with self._lock:
                    self.successful_requests += 1
                return None

            # compute tokens and cost for this request
            token_usage = response.llm_output["token_usage"]
            completion_tokens = token_usage.get("completion_tokens", 0)
            prompt_tokens = token_usage.get("prompt_tokens", 0)
            model_name = response.llm_output.get("model_name", "")

        total_cost = _get_token_cost(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            model_id=model_name,
        )

        # update shared state behind lock
        with self._lock:
            self.total_cost += total_cost
            self.total_tokens += token_usage.get("total_tokens", 0)
            self.prompt_tokens += prompt_tokens
            self.completion_tokens += completion_tokens
            self.successful_requests += 1

    def __copy__(self) -> "TokenUsageCallbackHandler":
        """Return a copy of the callback handler."""
        return self

    def __deepcopy__(self, memo: Any) -> "TokenUsageCallbackHandler":
        """Return a deep copy of the callback handler."""
        return self


token_usage_callback_var: ContextVar[TokenUsageCallbackHandler | None] = ContextVar(
    "token_usage_callback_var", default=None
)
register_configure_hook(token_usage_callback_var, True)


@contextmanager
def get_model_callback() -> Generator[TokenUsageCallbackHandler, None, None]:
    cb = TokenUsageCallbackHandler()
    token_usage_callback_var.set(cb)
    yield cb
    token_usage_callback_var.set(None)
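Usage sketch (not part of this release's diff): because the handler is wired into LangChain's callback configuration via register_configure_hook, any model run inside the get_model_callback() context should accumulate its token counts and estimated cost on the yielded handler. The ChatOpenAI construction and the prompt below are illustrative assumptions, not code from the package.

from langchain_openai import ChatOpenAI

from janus.llm.model_callbacks import get_model_callback

# Any LangChain chat model should work here; assumes OPENAI_API_KEY is set.
llm = ChatOpenAI(model_name="gpt-4o-2024-05-13")

with get_model_callback() as cb:
    llm.invoke("Summarize what a token-usage callback does.")

# The handler keeps its accumulated totals after the context exits.
print(cb)  # Tokens Used / Prompt Tokens / Completion Tokens / Successful Requests / Total Cost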
janus/llm/models_info.py
CHANGED
@@ -1,115 +1,179 @@
 import json
 import os
 from pathlib import Path
-from typing import Any,
+from typing import Any, Callable

 from dotenv import load_dotenv
+from langchain_community.chat_models import BedrockChat
+from langchain_community.llms import HuggingFaceTextGenInference
+from langchain_community.llms.bedrock import Bedrock
 from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain_core.language_models import BaseLanguageModel
+from langchain_openai import ChatOpenAI
+
+from janus.llm.model_callbacks import COST_PER_1K_TOKENS
+from janus.prompts.prompt import (
+    ChatGptPromptEngine,
+    ClaudePromptEngine,
+    CoherePromptEngine,
+    Llama2PromptEngine,
+    Llama3PromptEngine,
+    PromptEngine,
+    TitanPromptEngine,
+)

 load_dotenv()

+openai_model_reroutes = {
+    "gpt-4o": "gpt-4o-2024-05-13",
+    "gpt-4": "gpt-4-0613",
+    "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
+    "gpt-4-turbo-preview": "gpt-4-0125-preview",
+    "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
+    "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
+}
+
+openai_models = [
+    "gpt-4-0613",
+    "gpt-4-1106-preview",
+    "gpt-4-0125-preview",
+    "gpt-4o-2024-05-13",
+    "gpt-3.5-turbo-0125",
+]
+claude_models = [
+    "bedrock-claude-v2",
+    "bedrock-claude-instant-v1",
+    "bedrock-claude-haiku",
+    "bedrock-claude-sonnet",
+]
+llama2_models = [
+    "bedrock-llama2-70b",
+    "bedrock-llama2-70b-chat",
+    "bedrock-llama2-13b",
+    "bedrock-llama2-13b-chat",
+]
+llama3_models = [
+    "bedrock-llama3-8b-instruct",
+    "bedrock-llama3-70b-instruct",
+]
+titan_models = [
+    "bedrock-titan-text-lite",
+    "bedrock-titan-text-express",
+    "bedrock-jurassic-2-mid",
+    "bedrock-jurassic-2-ultra",
+]
+cohere_models = [
+    "bedrock-command-r-plus",
+]
+bedrock_models = [
+    *claude_models,
+    *llama2_models,
+    *llama3_models,
+    *titan_models,
+    *cohere_models,
+]
+all_models = [*openai_models, *bedrock_models]
+
+MODEL_TYPE_CONSTRUCTORS: dict[str, Callable[[Any], BaseLanguageModel]] = {
     "OpenAI": ChatOpenAI,
     "HuggingFace": HuggingFaceTextGenInference,
     "HuggingFaceLocal": HuggingFacePipeline.from_model_id,
+    "Bedrock": Bedrock,
+    "BedrockChat": BedrockChat,
 }


-    "mitre-falcon": "HuggingFace",
-    "mitre-wizard-coder": "HuggingFace",
+MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
+    **{m: ChatGptPromptEngine for m in openai_models},
+    **{m: ClaudePromptEngine for m in claude_models},
+    **{m: Llama2PromptEngine for m in llama2_models},
+    **{m: Llama3PromptEngine for m in llama3_models},
+    **{m: TitanPromptEngine for m in titan_models},
+    **{m: CoherePromptEngine for m in cohere_models},
 }

-_open_ai_defaults:
+_open_ai_defaults: dict[str, str] = {
     "openai_api_key": os.getenv("OPENAI_API_KEY"),
     "openai_organization": os.getenv("OPENAI_ORG_ID"),
 }

-        temperature=0.01,
-        repetition_penalty=1.03,
-        timeout=240,
-    ),
-    "mitre-wizard-coder": dict(
-        inference_server_url="https://wizard-coder-34b.aip.mitre.org",
-        max_new_tokens=1024,
-        top_k=10,
-        top_p=0.95,
-        typical_p=0.95,
-        temperature=0.01,
-        repetition_penalty=1.03,
-        timeout=240,
-    ),
+model_identifiers = {
+    **{m: m for m in openai_models},
+    "bedrock-claude-v2": "anthropic.claude-v2",
+    "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
+    "bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
+    "bedrock-claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
+    "bedrock-llama2-70b": "meta.llama2-70b-v1",
+    "bedrock-llama2-70b-chat": "meta.llama2-70b-chat-v1",
+    "bedrock-llama2-13b": "meta.llama2-13b-chat-v1",
+    "bedrock-llama2-13b-chat": "meta.llama2-13b-v1",
+    "bedrock-llama3-8b-instruct": "meta.llama3-8b-instruct-v1:0",
+    "bedrock-llama3-70b-instruct": "meta.llama3-70b-instruct-v1:0",
+    "bedrock-titan-text-lite": "amazon.titan-text-lite-v1",
+    "bedrock-titan-text-express": "amazon.titan-text-express-v1",
+    "bedrock-jurassic-2-mid": "ai21.j2-mid-v1",
+    "bedrock-jurassic-2-ultra": "ai21.j2-ultra-v1",
+    "bedrock-command-r-plus": "cohere.command-r-plus-v1:0",
+}
+
+MODEL_DEFAULT_ARGUMENTS: dict[str, dict[str, str]] = {
+    k: (dict(model_name=k) if k in openai_models else dict(model_id=v))
+    for k, v in model_identifiers.items()
 }

 DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())

 MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"

+MODEL_TYPES: dict[str, PromptEngine] = {
+    **{model_identifiers[m]: "OpenAI" for m in openai_models},
+    **{model_identifiers[m]: "BedrockChat" for m in bedrock_models},
+}
+
+TOKEN_LIMITS: dict[str, int] = {
     "gpt-4-32k": 32_768,
+    "gpt-4-0613": 8192,
     "gpt-4-1106-preview": 128_000,
-    "gpt-
-    "gpt-
+    "gpt-4-0125-preview": 128_000,
+    "gpt-4o-2024-05-13": 128_000,
+    "gpt-3.5-turbo-0125": 16_384,
     "text-embedding-ada-002": 8191,
     "gpt4all": 16_384,
+    "anthropic.claude-v2": 100_000,
+    "anthropic.claude-instant-v1": 100_000,
+    "anthropic.claude-3-haiku-20240307-v1:0": 248_000,
+    "anthropic.claude-3-sonnet-20240229-v1:0": 248_000,
+    "meta.llama2-70b-v1": 4096,
+    "meta.llama2-70b-chat-v1": 4096,
+    "meta.llama2-13b-chat-v1": 4096,
+    "meta.llama2-13b-v1": 4096,
+    "meta.llama3-8b-instruct-v1:0": 8000,
+    "meta.llama3-70b-instruct-v1:0": 8000,
+    "amazon.titan-text-lite-v1": 4096,
+    "amazon.titan-text-express-v1": 8192,
+    "ai21.j2-mid-v1": 8192,
+    "ai21.j2-ultra-v1": 8192,
+    "cohere.command-r-plus-v1:0": 128_000,
 }


-def load_model(model_name: str) ->
+def load_model(model_name: str) -> tuple[BaseLanguageModel, int, dict[str, float]]:
     if not MODEL_CONFIG_DIR.exists():
         MODEL_CONFIG_DIR.mkdir(parents=True)
     model_config_file = MODEL_CONFIG_DIR / f"{model_name}.json"
     if not model_config_file.exists():
         if model_name not in DEFAULT_MODELS:
+            if model_name in openai_model_reroutes:
+                model_name = openai_model_reroutes[model_name]
+            else:
+                raise ValueError(f"Error: could not find model {model_name}")
         model_config = {
             "model_type": MODEL_TYPES[model_name],
             "model_args": MODEL_DEFAULT_ARGUMENTS[model_name],
-            "token_limit": TOKEN_LIMITS.get(model_name, 4096),
-            "model_cost":
+            "token_limit": TOKEN_LIMITS.get(model_identifiers[model_name], 4096),
+            "model_cost": COST_PER_1K_TOKENS.get(
+                model_identifiers[model_name], {"input": 0, "output": 0}
+            ),
         }
         with open(model_config_file, "w") as f:
             json.dump(model_config, f)
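Usage sketch inferred only from the hunk above: the rest of load_model is not shown, so the unpacking below is an assumption based on the tuple[BaseLanguageModel, int, dict[str, float]] annotation; the printed values come from the TOKEN_LIMITS and COST_PER_1K_TOKENS tables in this release.

from janus.llm.models_info import DEFAULT_MODELS, load_model

print(DEFAULT_MODELS)  # e.g. "gpt-4o-2024-05-13", "bedrock-claude-sonnet", ...

# "gpt-4o" is rerouted to "gpt-4o-2024-05-13" via openai_model_reroutes; on first use
# a default config is written under ~/.janus/llm/<model>.json.
llm, token_limit, cost_per_1k = load_model("gpt-4o")
print(token_limit)           # 128_000 for gpt-4o-2024-05-13 per TOKEN_LIMITS
print(cost_per_1k["input"])  # 0.005 USD per 1K input tokens per COST_PER_1K_TOKENS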
janus/metrics/_tests/test_bleu.py
@@ -0,0 +1,56 @@
import unittest

from sacrebleu import sentence_bleu

from ..bleu import bleu


class TestBLEU(unittest.TestCase):
    def setUp(self):
        self.target_text = "This is a source text."
        self.reference_text = "This is a destination text."

    def test_bleu(self):
        """Test the BLEU score calculation."""
        function_score = (
            sentence_bleu(
                self.target_text,
                [self.reference_text],
            ).score
            / 100.0
        )
        expected_score = bleu(
            self.target_text,
            self.reference_text,
        )
        self.assertEqual(function_score, expected_score)

    def test_bleu_with_s_flag(self):
        """Test the BLEU score calculation with the -S flag."""
        function_score = (
            sentence_bleu(
                self.target_text,
                [self.reference_text],
            ).score
            / 100.0
        )
        score_with_s_flag = bleu(
            self.target_text,
            self.reference_text,
            use_strings=True,  # Mimics -S
        )
        self.assertEqual(function_score, score_with_s_flag)

    def test_bleu_invalid_target_type(self):
        """Test the BLEU score calculation with invalid source text type."""
        with self.assertRaises(TypeError):
            sentence_bleu(123, [self.reference_text])

    def test_bleu_invalid_reference_type(self):
        """Test the BLEU score calculation with invalid destination text type."""
        with self.assertRaises(TypeError):
            sentence_bleu(self.target_text, 123)


if __name__ == "__main__":
    unittest.main()
janus/metrics/_tests/test_chrf.py
@@ -0,0 +1,67 @@
import unittest

from sacrebleu import sentence_chrf

from ..chrf import chrf


class TestChrF(unittest.TestCase):
    def setUp(self):
        self.target_text = "This is a source text."
        self.reference_text = "This is a destination text."
        self.char_order = 6
        self.word_order = 2
        self.beta = 2.0

    def test_chrf_custom_params(self):
        """Test the chrf function with custom parameters."""
        function_score = chrf(
            self.target_text,
            self.reference_text,
            self.char_order,
            self.word_order,
            self.beta,
        )
        score = sentence_chrf(
            hypothesis=self.target_text,
            references=[self.reference_text],
            char_order=self.char_order,
            word_order=self.word_order,
            beta=self.beta,
        )
        expected_score = float(score.score) / 100.0
        self.assertEqual(function_score, expected_score)

    def test_chrf_with_s_flag(self):
        """Test the CHRF score calculation with the -S flag."""
        function_score = sentence_chrf(
            hypothesis=self.target_text,
            references=[self.reference_text],
            char_order=self.char_order,
            word_order=self.word_order,
            beta=self.beta,
        )
        function_score = float(function_score.score) / 100.0
        score_with_s_flag = chrf(
            self.target_text,
            self.reference_text,
            self.char_order,
            self.word_order,
            self.beta,
            use_strings=True,  # Mimics -S
        )
        self.assertEqual(function_score, score_with_s_flag)

    def test_chrf_invalid_target_type(self):
        """Test the chrf function with invalid source text type."""
        with self.assertRaises(TypeError):
            chrf(123, self.reference_text, self.char_order, self.word_order, self.beta)

    def test_chrf_invalid_reference_type(self):
        """Test the chrf function with invalid destination text type."""
        with self.assertRaises(TypeError):
            chrf(self.target_text, 123, self.char_order, self.word_order, self.beta)


if __name__ == "__main__":
    unittest.main()
janus/metrics/_tests/test_file_pairing.py
@@ -0,0 +1,59 @@
# FILEPATH: /Users/mdoyle/projects/janus/janus/metrics/tests/test_file_pairing.py

import unittest
from pathlib import Path

from ..file_pairing import (
    FILE_PAIRING_METHODS,
    pair_by_file,
    pair_by_line,
    pair_by_line_comment,
    register_pairing_method,
)


class TestFilePairing(unittest.TestCase):
    def setUp(self):
        self.src = "Hello\nWorld"
        self.cmp = "Hello\nPython"
        self.state = {
            "token_limit": 100,
            "llm": None,
            "lang": "python",
            "target_file": self.src,
            "cmp_file": self.cmp,
        }

    def test_register_pairing_method(self):
        @register_pairing_method(name="test")
        def test_method(src, cmp, state):
            return [(src, cmp)]

        self.assertIn("test", FILE_PAIRING_METHODS)

    def test_pair_by_file(self):
        expected = [(self.src, self.cmp)]
        result = pair_by_file(self.src, self.cmp)
        self.assertEqual(result, expected)

    def test_pair_by_line(self):
        expected = [("Hello", "Hello"), ("World", "Python")]
        result = pair_by_line(self.src, self.cmp)
        self.assertEqual(result, expected)

    def test_pair_by_line_comment(self):
        # This test assumes that the source and comparison files have comments on the
        # same lines
        # You may need to adjust this test based on your specific use case
        self.target = Path(__file__).parent / "target.py"
        self.reference = Path(__file__).parent / "reference.py"
        kwargs = {
            "token_limit": 100,
            "llm": None,
            "lang": "python",
            "target_file": self.target,
            "reference_file": self.reference,
        }
        expected = [("# Hello\n", "# Hello\n")]
        result = pair_by_line_comment(self.src, self.cmp, **kwargs)
        self.assertEqual(result, expected)
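A short sketch of how the pairing helpers appear to behave, inferred only from the tests above; the absolute module path janus.metrics.file_pairing follows the file listing, and the custom method below is a hypothetical example, not part of the package.

from janus.metrics.file_pairing import (
    FILE_PAIRING_METHODS,
    pair_by_line,
    register_pairing_method,
)

# pair_by_line appears to pair corresponding lines of the two inputs.
pairs = pair_by_line("Hello\nWorld", "Hello\nPython")
assert pairs == [("Hello", "Hello"), ("World", "Python")]

# Custom strategies register themselves by name, as in test_register_pairing_method.
@register_pairing_method(name="whole-file")
def whole_file(src, cmp, state):
    return [(src, cmp)]

assert "whole-file" in FILE_PAIRING_METHODS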