janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +9 -1
- janus/__main__.py +4 -0
- janus/_tests/test_cli.py +128 -0
- janus/_tests/test_translate.py +49 -7
- janus/cli.py +530 -46
- janus/converter.py +50 -19
- janus/embedding/_tests/test_collections.py +2 -8
- janus/embedding/_tests/test_database.py +32 -0
- janus/embedding/_tests/test_vectorize.py +9 -4
- janus/embedding/collections.py +49 -6
- janus/embedding/embedding_models_info.py +130 -0
- janus/embedding/vectorize.py +53 -62
- janus/language/_tests/__init__.py +0 -0
- janus/language/_tests/test_combine.py +62 -0
- janus/language/_tests/test_splitter.py +16 -0
- janus/language/binary/_tests/test_binary.py +16 -1
- janus/language/binary/binary.py +10 -3
- janus/language/block.py +31 -30
- janus/language/combine.py +26 -34
- janus/language/mumps/_tests/test_mumps.py +2 -2
- janus/language/mumps/mumps.py +93 -9
- janus/language/naive/__init__.py +4 -0
- janus/language/naive/basic_splitter.py +14 -0
- janus/language/naive/chunk_splitter.py +26 -0
- janus/language/naive/registry.py +13 -0
- janus/language/naive/simple_ast.py +18 -0
- janus/language/naive/tag_splitter.py +61 -0
- janus/language/splitter.py +168 -74
- janus/language/treesitter/_tests/test_treesitter.py +19 -14
- janus/language/treesitter/treesitter.py +37 -13
- janus/llm/model_callbacks.py +177 -0
- janus/llm/models_info.py +165 -72
- janus/metrics/__init__.py +8 -0
- janus/metrics/_tests/__init__.py +0 -0
- janus/metrics/_tests/reference.py +2 -0
- janus/metrics/_tests/target.py +2 -0
- janus/metrics/_tests/test_bleu.py +56 -0
- janus/metrics/_tests/test_chrf.py +67 -0
- janus/metrics/_tests/test_file_pairing.py +59 -0
- janus/metrics/_tests/test_llm.py +91 -0
- janus/metrics/_tests/test_reading.py +28 -0
- janus/metrics/_tests/test_rouge_score.py +65 -0
- janus/metrics/_tests/test_similarity_score.py +23 -0
- janus/metrics/_tests/test_treesitter_metrics.py +110 -0
- janus/metrics/bleu.py +66 -0
- janus/metrics/chrf.py +55 -0
- janus/metrics/cli.py +7 -0
- janus/metrics/complexity_metrics.py +208 -0
- janus/metrics/file_pairing.py +113 -0
- janus/metrics/llm_metrics.py +202 -0
- janus/metrics/metric.py +466 -0
- janus/metrics/reading.py +70 -0
- janus/metrics/rouge_score.py +96 -0
- janus/metrics/similarity.py +53 -0
- janus/metrics/splitting.py +38 -0
- janus/parsers/_tests/__init__.py +0 -0
- janus/parsers/_tests/test_code_parser.py +32 -0
- janus/parsers/code_parser.py +24 -253
- janus/parsers/doc_parser.py +169 -0
- janus/parsers/eval_parser.py +80 -0
- janus/parsers/reqs_parser.py +72 -0
- janus/prompts/prompt.py +103 -30
- janus/translate.py +636 -111
- janus/utils/_tests/__init__.py +0 -0
- janus/utils/_tests/test_logger.py +67 -0
- janus/utils/_tests/test_progress.py +20 -0
- janus/utils/enums.py +56 -3
- janus/utils/progress.py +56 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
- janus_llm-2.0.1.dist-info/RECORD +94 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
- janus_llm-1.0.0.dist-info/RECORD +0 -48
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
janus/llm/model_callbacks.py
ADDED
@@ -0,0 +1,177 @@
+import threading
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import Any, Generator
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration, LLMResult
+from langchain_core.tracers.context import register_configure_hook
+
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+# Updated 2024-06-21
+COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
+    "gpt-3.5-turbo-0125": {"input": 0.0005, "output": 0.0015},
+    "gpt-4-1106-preview": {"input": 0.01, "output": 0.03},
+    "gpt-4-0125-preview": {"input": 0.01, "output": 0.03},
+    "gpt-4-0613": {"input": 0.03, "output": 0.06},
+    "gpt-4o-2024-05-13": {"input": 0.005, "output": 0.015},
+    "anthropic.claude-v2": {"input": 0.008, "output": 0.024},
+    "anthropic.claude-instant-v1": {"input": 0.0008, "output": 0.0024},
+    "anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.00025, "output": 0.00125},
+    "anthropic.claude-3-sonnet-20240229-v1:0": {"input": 0.003, "output": 0.015},
+    "meta.llama2-13b-chat-v1": {"input": 0.00075, "output": 0.001},
+    "meta.llama2-70b-chat-v1": {"input": 0.00195, "output": 0.00256},
+    "meta.llama2-13b-v1": {"input": 0.0, "output": 0.0},
+    "meta.llama2-70b-v1": {"input": 0.00265, "output": 0.0035},
+    "meta.llama3-8b-instruct-v1:0": {"input": 0.0003, "output": 0.0006},
+    "meta.llama3-70b-instruct-v1:0": {"input": 0.00265, "output": 0.0035},
+    "amazon.titan-text-lite-v1": {"input": 0.00015, "output": 0.0002},
+    "amazon.titan-text-express-v1": {"input": 0.0002, "output": 0.0006},
+    "ai21.j2-mid-v1": {"input": 0.0125, "output": 0.0125},
+    "ai21.j2-ultra-v1": {"input": 0.0188, "output": 0.0188},
+    "cohere.command-r-plus-v1:0": {"input": 0.003, "output": 0.015},
+}
+
+
+def _get_token_cost(
+    prompt_tokens: int, completion_tokens: int, model_id: str | None
+) -> float:
+    """Get the cost of tokens according to model ID"""
+    if model_id not in COST_PER_1K_TOKENS:
+        raise ValueError(
+            f"Unknown model: {model_id}. Please provide a valid model name."
+            f" Known models are: {', '.join(COST_PER_1K_TOKENS.keys())}"
+        )
+    model_cost = COST_PER_1K_TOKENS[model_id]
+    input_cost = (prompt_tokens / 1000.0) * model_cost["input"]
+    output_cost = (completion_tokens / 1000.0) * model_cost["output"]
+    return input_cost + output_cost
+
+
+class TokenUsageCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that tracks metadata on model cost, retries, etc.
+    Based on https://github.com/langchain-ai/langchain/blob/master/libs
+    /community/langchain_community/callbacks/openai_info.py
+    """
+
+    total_tokens: int = 0
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    successful_requests: int = 0
+    total_cost: float = 0.0
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._lock = threading.Lock()
+
+    def __repr__(self) -> str:
+        return (
+            f"Tokens Used: {self.total_tokens}\n"
+            f"\tPrompt Tokens: {self.prompt_tokens}\n"
+            f"\tCompletion Tokens: {self.completion_tokens}\n"
+            f"Successful Requests: {self.successful_requests}\n"
+            f"Total Cost (USD): ${self.total_cost}"
+        )
+
+    @property
+    def always_verbose(self) -> bool:
+        """Whether to call verbose callbacks even if verbose is False."""
+        return True
+
+    def on_chat_model_start(self, *args, **kwargs):
+        pass
+
+    def on_llm_start(
+        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
+    ) -> None:
+        """Print out the prompts."""
+        pass
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Print out the token."""
+        pass
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Collect token usage."""
+        # Check for usage_metadata (langchain-core >= 0.2.2)
+        try:
+            generation = response.generations[0][0]
+        except IndexError:
+            generation = None
+        if isinstance(generation, ChatGeneration):
+            try:
+                message = generation.message
+                if isinstance(message, AIMessage):
+                    usage_metadata = message.usage_metadata
+                else:
+                    usage_metadata = None
+            except AttributeError:
+                usage_metadata = None
+        else:
+            usage_metadata = None
+        if usage_metadata:
+            token_usage = {"total_tokens": usage_metadata["total_tokens"]}
+            completion_tokens = usage_metadata["output_tokens"]
+            prompt_tokens = usage_metadata["input_tokens"]
+            if response.llm_output is None:
+                # model name (and therefore cost) is unavailable in
+                # streaming responses
+                model_name = ""
+            else:
+                model_name = response.llm_output.get("model_name", "")
+
+        else:
+            if response.llm_output is None:
+                return None
+
+            if "token_usage" not in response.llm_output:
+                with self._lock:
+                    self.successful_requests += 1
+                return None
+
+            # compute tokens and cost for this request
+            token_usage = response.llm_output["token_usage"]
+            completion_tokens = token_usage.get("completion_tokens", 0)
+            prompt_tokens = token_usage.get("prompt_tokens", 0)
+            model_name = response.llm_output.get("model_name", "")
+
+        total_cost = _get_token_cost(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            model_id=model_name,
+        )
+
+        # update shared state behind lock
+        with self._lock:
+            self.total_cost += total_cost
+            self.total_tokens += token_usage.get("total_tokens", 0)
+            self.prompt_tokens += prompt_tokens
+            self.completion_tokens += completion_tokens
+            self.successful_requests += 1
+
+    def __copy__(self) -> "TokenUsageCallbackHandler":
+        """Return a copy of the callback handler."""
+        return self
+
+    def __deepcopy__(self, memo: Any) -> "TokenUsageCallbackHandler":
+        """Return a deep copy of the callback handler."""
+        return self
+
+
+token_usage_callback_var: ContextVar[TokenUsageCallbackHandler | None] = ContextVar(
+    "token_usage_callback_var", default=None
+)
+register_configure_hook(token_usage_callback_var, True)
+
+
+@contextmanager
+def get_model_callback() -> Generator[TokenUsageCallbackHandler, None, None]:
+    cb = TokenUsageCallbackHandler()
+    token_usage_callback_var.set(cb)
+    yield cb
+    token_usage_callback_var.set(None)
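For orientation, a minimal usage sketch of the new module (not part of the diff; the model name is only illustrative). Because `register_configure_hook` ties the handler to a ContextVar, a LangChain call made inside the `get_model_callback()` context should report its usage to the yielded handler.

```python
# Hypothetical usage sketch based on the handler and context manager defined above.
from langchain_openai import ChatOpenAI

from janus.llm.model_callbacks import get_model_callback

llm = ChatOpenAI(model_name="gpt-4o-2024-05-13")
with get_model_callback() as cb:
    llm.invoke("Summarize this FORTRAN routine.")  # usage collected in on_llm_end
print(cb.total_tokens, cb.prompt_tokens, cb.completion_tokens)
print(f"Total Cost (USD): ${cb.total_cost:.4f}")
```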
janus/llm/models_info.py
CHANGED
@@ -1,115 +1,208 @@
 import json
 import os
 from pathlib import Path
-from typing import Any,
+from typing import Any, Callable
 
 from dotenv import load_dotenv
-from
-from
-from
+from langchain_community.llms import HuggingFaceTextGenInference
+from langchain_core.language_models import BaseLanguageModel
+from langchain_openai import ChatOpenAI
+
+from janus.llm.model_callbacks import COST_PER_1K_TOKENS
+from janus.prompts.prompt import (
+    ChatGptPromptEngine,
+    ClaudePromptEngine,
+    CoherePromptEngine,
+    Llama2PromptEngine,
+    Llama3PromptEngine,
+    PromptEngine,
+    TitanPromptEngine,
+)
+
+from ..utils.logger import create_logger
+
+log = create_logger(__name__)
+
+try:
+    from langchain_community.chat_models import BedrockChat
+    from langchain_community.llms.bedrock import Bedrock
+except ImportError:
+    log.warning(
+        "Could not import LangChain's Bedrock Client. If you would like to use Bedrock "
+        "models, please install LangChain's Bedrock Client by running 'pip install "
+        "janus-llm[bedrock]' or poetry install -E bedrock."
+    )
+
+try:
+    from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+except ImportError:
+    log.warning(
+        "Could not import LangChain's HuggingFace Pipeline Client. If you would like to "
+        "use HuggingFace models, please install LangChain's HuggingFace Pipeline Client "
+        "by running 'pip install janus-llm[hf-local]' or poetry install -E hf-local."
+    )
+
 
 load_dotenv()
 
-
+openai_model_reroutes = {
+    "gpt-4o": "gpt-4o-2024-05-13",
+    "gpt-4": "gpt-4-0613",
+    "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
+    "gpt-4-turbo-preview": "gpt-4-0125-preview",
+    "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
+    "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
+}
+
+openai_models = [
+    "gpt-4-0613",
+    "gpt-4-1106-preview",
+    "gpt-4-0125-preview",
+    "gpt-4o-2024-05-13",
+    "gpt-3.5-turbo-0125",
+]
+claude_models = [
+    "bedrock-claude-v2",
+    "bedrock-claude-instant-v1",
+    "bedrock-claude-haiku",
+    "bedrock-claude-sonnet",
+]
+llama2_models = [
+    "bedrock-llama2-70b",
+    "bedrock-llama2-70b-chat",
+    "bedrock-llama2-13b",
+    "bedrock-llama2-13b-chat",
+]
+llama3_models = [
+    "bedrock-llama3-8b-instruct",
+    "bedrock-llama3-70b-instruct",
+]
+titan_models = [
+    "bedrock-titan-text-lite",
+    "bedrock-titan-text-express",
+    "bedrock-jurassic-2-mid",
+    "bedrock-jurassic-2-ultra",
+]
+cohere_models = [
+    "bedrock-command-r-plus",
+]
+bedrock_models = [
+    *claude_models,
+    *llama2_models,
+    *llama3_models,
+    *titan_models,
+    *cohere_models,
+]
+all_models = [*openai_models, *bedrock_models]
+
+MODEL_TYPE_CONSTRUCTORS: dict[str, Callable[[Any], BaseLanguageModel]] = {
     "OpenAI": ChatOpenAI,
     "HuggingFace": HuggingFaceTextGenInference,
-    "HuggingFaceLocal": HuggingFacePipeline.from_model_id,
 }
 
+try:
+    MODEL_TYPE_CONSTRUCTORS.update(
+        {
+            "HuggingFaceLocal": HuggingFacePipeline.from_model_id,
+            "Bedrock": Bedrock,
+            "BedrockChat": BedrockChat,
+        }
+    )
+except NameError:
+    pass
+
 
-
-
-
-
-
-
-
-    "mitre-falcon": "HuggingFace",
-    "mitre-wizard-coder": "HuggingFace",
+MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
+    **{m: ChatGptPromptEngine for m in openai_models},
+    **{m: ClaudePromptEngine for m in claude_models},
+    **{m: Llama2PromptEngine for m in llama2_models},
+    **{m: Llama3PromptEngine for m in llama3_models},
+    **{m: TitanPromptEngine for m in titan_models},
+    **{m: CoherePromptEngine for m in cohere_models},
 }
 
-_open_ai_defaults:
+_open_ai_defaults: dict[str, str] = {
     "openai_api_key": os.getenv("OPENAI_API_KEY"),
     "openai_organization": os.getenv("OPENAI_ORG_ID"),
 }
 
-
-
-"
-"
-"
-"
-"
-
-
-
-
-
-
-
-
-
-"
-
-
-
-
-
-        temperature=0.01,
-        repetition_penalty=1.03,
-        timeout=240,
-    ),
-    "mitre-wizard-coder": dict(
-        inference_server_url="https://wizard-coder-34b.aip.mitre.org",
-        max_new_tokens=1024,
-        top_k=10,
-        top_p=0.95,
-        typical_p=0.95,
-        temperature=0.01,
-        repetition_penalty=1.03,
-        timeout=240,
-    ),
+model_identifiers = {
+    **{m: m for m in openai_models},
+    "bedrock-claude-v2": "anthropic.claude-v2",
+    "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
+    "bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
+    "bedrock-claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
+    "bedrock-llama2-70b": "meta.llama2-70b-v1",
+    "bedrock-llama2-70b-chat": "meta.llama2-70b-chat-v1",
+    "bedrock-llama2-13b": "meta.llama2-13b-chat-v1",
+    "bedrock-llama2-13b-chat": "meta.llama2-13b-v1",
+    "bedrock-llama3-8b-instruct": "meta.llama3-8b-instruct-v1:0",
+    "bedrock-llama3-70b-instruct": "meta.llama3-70b-instruct-v1:0",
+    "bedrock-titan-text-lite": "amazon.titan-text-lite-v1",
+    "bedrock-titan-text-express": "amazon.titan-text-express-v1",
+    "bedrock-jurassic-2-mid": "ai21.j2-mid-v1",
+    "bedrock-jurassic-2-ultra": "ai21.j2-ultra-v1",
+    "bedrock-command-r-plus": "cohere.command-r-plus-v1:0",
+}
+
+MODEL_DEFAULT_ARGUMENTS: dict[str, dict[str, str]] = {
+    k: (dict(model_name=k) if k in openai_models else dict(model_id=v))
+    for k, v in model_identifiers.items()
 }
 
 DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
 
 MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
 
-
-"
+MODEL_TYPES: dict[str, PromptEngine] = {
+    **{m: "OpenAI" for m in openai_models},
+    **{m: "BedrockChat" for m in bedrock_models},
+}
+
+TOKEN_LIMITS: dict[str, int] = {
     "gpt-4-32k": 32_768,
+    "gpt-4-0613": 8192,
     "gpt-4-1106-preview": 128_000,
-    "gpt-
-    "gpt-
-    "
+    "gpt-4-0125-preview": 128_000,
+    "gpt-4o-2024-05-13": 128_000,
+    "gpt-3.5-turbo-0125": 16_384,
     "text-embedding-ada-002": 8191,
     "gpt4all": 16_384,
-
-
-
-    "
-    "
-    "
-    "
-    "
-    "
-    "
-    "
+    "anthropic.claude-v2": 100_000,
+    "anthropic.claude-instant-v1": 100_000,
+    "anthropic.claude-3-haiku-20240307-v1:0": 248_000,
+    "anthropic.claude-3-sonnet-20240229-v1:0": 248_000,
+    "meta.llama2-70b-v1": 4096,
+    "meta.llama2-70b-chat-v1": 4096,
+    "meta.llama2-13b-chat-v1": 4096,
+    "meta.llama2-13b-v1": 4096,
+    "meta.llama3-8b-instruct-v1:0": 8000,
+    "meta.llama3-70b-instruct-v1:0": 8000,
+    "amazon.titan-text-lite-v1": 4096,
+    "amazon.titan-text-express-v1": 8192,
+    "ai21.j2-mid-v1": 8192,
+    "ai21.j2-ultra-v1": 8192,
+    "cohere.command-r-plus-v1:0": 128_000,
 }
 
 
-def load_model(model_name: str) ->
+def load_model(model_name: str) -> tuple[BaseLanguageModel, int, dict[str, float]]:
     if not MODEL_CONFIG_DIR.exists():
         MODEL_CONFIG_DIR.mkdir(parents=True)
     model_config_file = MODEL_CONFIG_DIR / f"{model_name}.json"
     if not model_config_file.exists():
         if model_name not in DEFAULT_MODELS:
-
+            if model_name in openai_model_reroutes:
+                model_name = openai_model_reroutes[model_name]
+            else:
+                raise ValueError(f"Error: could not find model {model_name}")
         model_config = {
             "model_type": MODEL_TYPES[model_name],
             "model_args": MODEL_DEFAULT_ARGUMENTS[model_name],
-            "token_limit": TOKEN_LIMITS.get(model_name, 4096),
-            "model_cost":
+            "token_limit": TOKEN_LIMITS.get(model_identifiers[model_name], 4096),
+            "model_cost": COST_PER_1K_TOKENS.get(
+                model_identifiers[model_name], {"input": 0, "output": 0}
+            ),
         }
         with open(model_config_file, "w") as f:
             json.dump(model_config, f)
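For reference, a hedged sketch of calling the reworked loader, based only on the annotated signature shown in this hunk (the remainder of `load_model` falls outside the diff): it should return the constructed LangChain model together with its token limit and its per-1K-token cost entry.

```python
# Hedged example; return semantics inferred from the signature above.
from janus.llm.models_info import load_model

llm, token_limit, model_cost = load_model("gpt-4o-2024-05-13")
print(token_limit)   # expected 128_000 per TOKEN_LIMITS
print(model_cost)    # expected {"input": 0.005, "output": 0.015} per COST_PER_1K_TOKENS
```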
janus/metrics/_tests/test_bleu.py
ADDED
@@ -0,0 +1,56 @@
+import unittest
+
+from sacrebleu import sentence_bleu
+
+from ..bleu import bleu
+
+
+class TestBLEU(unittest.TestCase):
+    def setUp(self):
+        self.target_text = "This is a source text."
+        self.reference_text = "This is a destination text."
+
+    def test_bleu(self):
+        """Test the BLEU score calculation."""
+        function_score = (
+            sentence_bleu(
+                self.target_text,
+                [self.reference_text],
+            ).score
+            / 100.0
+        )
+        expected_score = bleu(
+            self.target_text,
+            self.reference_text,
+        )
+        self.assertEqual(function_score, expected_score)
+
+    def test_bleu_with_s_flag(self):
+        """Test the BLEU score calculation with the -S flag."""
+        function_score = (
+            sentence_bleu(
+                self.target_text,
+                [self.reference_text],
+            ).score
+            / 100.0
+        )
+        score_with_s_flag = bleu(
+            self.target_text,
+            self.reference_text,
+            use_strings=True,  # Mimics -S
+        )
+        self.assertEqual(function_score, score_with_s_flag)
+
+    def test_bleu_invalid_target_type(self):
+        """Test the BLEU score calculation with invalid source text type."""
+        with self.assertRaises(TypeError):
+            sentence_bleu(123, [self.reference_text])
+
+    def test_bleu_invalid_reference_type(self):
+        """Test the BLEU score calculation with invalid destination text type."""
+        with self.assertRaises(TypeError):
+            sentence_bleu(self.target_text, 123)
+
+
+if __name__ == "__main__":
+    unittest.main()
janus/metrics/_tests/test_chrf.py
ADDED
@@ -0,0 +1,67 @@
+import unittest
+
+from sacrebleu import sentence_chrf
+
+from ..chrf import chrf
+
+
+class TestChrF(unittest.TestCase):
+    def setUp(self):
+        self.target_text = "This is a source text."
+        self.reference_text = "This is a destination text."
+        self.char_order = 6
+        self.word_order = 2
+        self.beta = 2.0
+
+    def test_chrf_custom_params(self):
+        """Test the chrf function with custom parameters."""
+        function_score = chrf(
+            self.target_text,
+            self.reference_text,
+            self.char_order,
+            self.word_order,
+            self.beta,
+        )
+        score = sentence_chrf(
+            hypothesis=self.target_text,
+            references=[self.reference_text],
+            char_order=self.char_order,
+            word_order=self.word_order,
+            beta=self.beta,
+        )
+        expected_score = float(score.score) / 100.0
+        self.assertEqual(function_score, expected_score)
+
+    def test_chrf_with_s_flag(self):
+        """Test the CHRF score calculation with the -S flag."""
+        function_score = sentence_chrf(
+            hypothesis=self.target_text,
+            references=[self.reference_text],
+            char_order=self.char_order,
+            word_order=self.word_order,
+            beta=self.beta,
+        )
+        function_score = float(function_score.score) / 100.0
+        score_with_s_flag = chrf(
+            self.target_text,
+            self.reference_text,
+            self.char_order,
+            self.word_order,
+            self.beta,
+            use_strings=True,  # Mimics -S
+        )
+        self.assertEqual(function_score, score_with_s_flag)
+
+    def test_chrf_invalid_target_type(self):
+        """Test the chrf function with invalid source text type."""
+        with self.assertRaises(TypeError):
+            chrf(123, self.reference_text, self.char_order, self.word_order, self.beta)
+
+    def test_chrf_invalid_reference_type(self):
+        """Test the chrf function with invalid destination text type."""
+        with self.assertRaises(TypeError):
+            chrf(self.target_text, 123, self.char_order, self.word_order, self.beta)
+
+
+if __name__ == "__main__":
+    unittest.main()
janus/metrics/_tests/test_file_pairing.py
ADDED
@@ -0,0 +1,59 @@
+# FILEPATH: /Users/mdoyle/projects/janus/janus/metrics/tests/test_file_pairing.py
+
+import unittest
+from pathlib import Path
+
+from ..file_pairing import (
+    FILE_PAIRING_METHODS,
+    pair_by_file,
+    pair_by_line,
+    pair_by_line_comment,
+    register_pairing_method,
+)
+
+
+class TestFilePairing(unittest.TestCase):
+    def setUp(self):
+        self.src = "Hello\nWorld"
+        self.cmp = "Hello\nPython"
+        self.state = {
+            "token_limit": 100,
+            "llm": None,
+            "lang": "python",
+            "target_file": self.src,
+            "cmp_file": self.cmp,
+        }
+
+    def test_register_pairing_method(self):
+        @register_pairing_method(name="test")
+        def test_method(src, cmp, state):
+            return [(src, cmp)]
+
+        self.assertIn("test", FILE_PAIRING_METHODS)
+
+    def test_pair_by_file(self):
+        expected = [(self.src, self.cmp)]
+        result = pair_by_file(self.src, self.cmp)
+        self.assertEqual(result, expected)
+
+    def test_pair_by_line(self):
+        expected = [("Hello", "Hello"), ("World", "Python")]
+        result = pair_by_line(self.src, self.cmp)
+        self.assertEqual(result, expected)
+
+    def test_pair_by_line_comment(self):
+        # This test assumes that the source and comparison files have comments on the
+        # same lines
+        # You may need to adjust this test based on your specific use case
+        self.target = Path(__file__).parent / "target.py"
+        self.reference = Path(__file__).parent / "reference.py"
+        kwargs = {
+            "token_limit": 100,
+            "llm": None,
+            "lang": "python",
+            "target_file": self.target,
+            "reference_file": self.reference,
+        }
+        expected = [("# Hello\n", "# Hello\n")]
+        result = pair_by_line_comment(self.src, self.cmp, **kwargs)
+        self.assertEqual(result, expected)