janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

Files changed (74)
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +130 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +19 -14
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +165 -72
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
  70. janus_llm-2.0.1.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
janus/llm/model_callbacks.py ADDED
@@ -0,0 +1,177 @@
+ import threading
+ from contextlib import contextmanager
+ from contextvars import ContextVar
+ from typing import Any, Generator
+
+ from langchain_core.callbacks import BaseCallbackHandler
+ from langchain_core.messages import AIMessage
+ from langchain_core.outputs import ChatGeneration, LLMResult
+ from langchain_core.tracers.context import register_configure_hook
+
+ from janus.utils.logger import create_logger
+
+ log = create_logger(__name__)
+
+
+ # Updated 2024-06-21
+ COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
+     "gpt-3.5-turbo-0125": {"input": 0.0005, "output": 0.0015},
+     "gpt-4-1106-preview": {"input": 0.01, "output": 0.03},
+     "gpt-4-0125-preview": {"input": 0.01, "output": 0.03},
+     "gpt-4-0613": {"input": 0.03, "output": 0.06},
+     "gpt-4o-2024-05-13": {"input": 0.005, "output": 0.015},
+     "anthropic.claude-v2": {"input": 0.008, "output": 0.024},
+     "anthropic.claude-instant-v1": {"input": 0.0008, "output": 0.0024},
+     "anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.00025, "output": 0.00125},
+     "anthropic.claude-3-sonnet-20240229-v1:0": {"input": 0.003, "output": 0.015},
+     "meta.llama2-13b-chat-v1": {"input": 0.00075, "output": 0.001},
+     "meta.llama2-70b-chat-v1": {"input": 0.00195, "output": 0.00256},
+     "meta.llama2-13b-v1": {"input": 0.0, "output": 0.0},
+     "meta.llama2-70b-v1": {"input": 0.00265, "output": 0.0035},
+     "meta.llama3-8b-instruct-v1:0": {"input": 0.0003, "output": 0.0006},
+     "meta.llama3-70b-instruct-v1:0": {"input": 0.00265, "output": 0.0035},
+     "amazon.titan-text-lite-v1": {"input": 0.00015, "output": 0.0002},
+     "amazon.titan-text-express-v1": {"input": 0.0002, "output": 0.0006},
+     "ai21.j2-mid-v1": {"input": 0.0125, "output": 0.0125},
+     "ai21.j2-ultra-v1": {"input": 0.0188, "output": 0.0188},
+     "cohere.command-r-plus-v1:0": {"input": 0.003, "output": 0.015},
+ }
+
+
+ def _get_token_cost(
+     prompt_tokens: int, completion_tokens: int, model_id: str | None
+ ) -> float:
+     """Get the cost of tokens according to model ID"""
+     if model_id not in COST_PER_1K_TOKENS:
+         raise ValueError(
+             f"Unknown model: {model_id}. Please provide a valid model name."
+             f" Known models are: {', '.join(COST_PER_1K_TOKENS.keys())}"
+         )
+     model_cost = COST_PER_1K_TOKENS[model_id]
+     input_cost = (prompt_tokens / 1000.0) * model_cost["input"]
+     output_cost = (completion_tokens / 1000.0) * model_cost["output"]
+     return input_cost + output_cost
+
+
+ class TokenUsageCallbackHandler(BaseCallbackHandler):
+     """Callback Handler that tracks metadata on model cost, retries, etc.
+     Based on https://github.com/langchain-ai/langchain/blob/master/libs
+     /community/langchain_community/callbacks/openai_info.py
+     """
+
+     total_tokens: int = 0
+     prompt_tokens: int = 0
+     completion_tokens: int = 0
+     successful_requests: int = 0
+     total_cost: float = 0.0
+
+     def __init__(self) -> None:
+         super().__init__()
+         self._lock = threading.Lock()
+
+     def __repr__(self) -> str:
+         return (
+             f"Tokens Used: {self.total_tokens}\n"
+             f"\tPrompt Tokens: {self.prompt_tokens}\n"
+             f"\tCompletion Tokens: {self.completion_tokens}\n"
+             f"Successful Requests: {self.successful_requests}\n"
+             f"Total Cost (USD): ${self.total_cost}"
+         )
+
+     @property
+     def always_verbose(self) -> bool:
+         """Whether to call verbose callbacks even if verbose is False."""
+         return True
+
+     def on_chat_model_start(self, *args, **kwargs):
+         pass
+
+     def on_llm_start(
+         self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
+     ) -> None:
+         """Print out the prompts."""
+         pass
+
+     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+         """Print out the token."""
+         pass
+
+     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+         """Collect token usage."""
+         # Check for usage_metadata (langchain-core >= 0.2.2)
+         try:
+             generation = response.generations[0][0]
+         except IndexError:
+             generation = None
+         if isinstance(generation, ChatGeneration):
+             try:
+                 message = generation.message
+                 if isinstance(message, AIMessage):
+                     usage_metadata = message.usage_metadata
+                 else:
+                     usage_metadata = None
+             except AttributeError:
+                 usage_metadata = None
+         else:
+             usage_metadata = None
+         if usage_metadata:
+             token_usage = {"total_tokens": usage_metadata["total_tokens"]}
+             completion_tokens = usage_metadata["output_tokens"]
+             prompt_tokens = usage_metadata["input_tokens"]
+             if response.llm_output is None:
+                 # model name (and therefore cost) is unavailable in
+                 # streaming responses
+                 model_name = ""
+             else:
+                 model_name = response.llm_output.get("model_name", "")
+
+         else:
+             if response.llm_output is None:
+                 return None
+
+             if "token_usage" not in response.llm_output:
+                 with self._lock:
+                     self.successful_requests += 1
+                 return None
+
+             # compute tokens and cost for this request
+             token_usage = response.llm_output["token_usage"]
+             completion_tokens = token_usage.get("completion_tokens", 0)
+             prompt_tokens = token_usage.get("prompt_tokens", 0)
+             model_name = response.llm_output.get("model_name", "")
+
+         total_cost = _get_token_cost(
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             model_id=model_name,
+         )
+
+         # update shared state behind lock
+         with self._lock:
+             self.total_cost += total_cost
+             self.total_tokens += token_usage.get("total_tokens", 0)
+             self.prompt_tokens += prompt_tokens
+             self.completion_tokens += completion_tokens
+             self.successful_requests += 1
+
+     def __copy__(self) -> "TokenUsageCallbackHandler":
+         """Return a copy of the callback handler."""
+         return self
+
+     def __deepcopy__(self, memo: Any) -> "TokenUsageCallbackHandler":
+         """Return a deep copy of the callback handler."""
+         return self
+
+
+ token_usage_callback_var: ContextVar[TokenUsageCallbackHandler | None] = ContextVar(
+     "token_usage_callback_var", default=None
+ )
+ register_configure_hook(token_usage_callback_var, True)
+
+
+ @contextmanager
+ def get_model_callback() -> Generator[TokenUsageCallbackHandler, None, None]:
+     cb = TokenUsageCallbackHandler()
+     token_usage_callback_var.set(cb)
+     yield cb
+     token_usage_callback_var.set(None)
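A minimal usage sketch for the new callback module (not part of the diff; the model alias, prompt, and printed values are illustrative, and it assumes OPENAI_API_KEY is set):

from janus.llm.model_callbacks import _get_token_cost, get_model_callback
from janus.llm.models_info import load_model

# Pricing is per 1K tokens; for gpt-4o-2024-05-13:
# (2000 / 1000) * 0.005 + (500 / 1000) * 0.015 = 0.0175
print(_get_token_cost(prompt_tokens=2000, completion_tokens=500, model_id="gpt-4o-2024-05-13"))

# register_configure_hook wires the handler into LangChain's config context,
# so any model call made inside the block is tallied automatically.
llm, _token_limit, _cost = load_model("gpt-4o")
with get_model_callback() as cb:
    llm.invoke("Say hello.")
print(cb)  # Tokens Used, Prompt/Completion split, Successful Requests, Total Cost (USD)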
janus/llm/models_info.py CHANGED
@@ -1,115 +1,208 @@
  import json
  import os
  from pathlib import Path
- from typing import Any, Dict, Tuple
+ from typing import Any, Callable

  from dotenv import load_dotenv
- from langchain.chat_models import ChatOpenAI
- from langchain.llms import HuggingFaceTextGenInference
- from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+ from langchain_community.llms import HuggingFaceTextGenInference
+ from langchain_core.language_models import BaseLanguageModel
+ from langchain_openai import ChatOpenAI
+
+ from janus.llm.model_callbacks import COST_PER_1K_TOKENS
+ from janus.prompts.prompt import (
+     ChatGptPromptEngine,
+     ClaudePromptEngine,
+     CoherePromptEngine,
+     Llama2PromptEngine,
+     Llama3PromptEngine,
+     PromptEngine,
+     TitanPromptEngine,
+ )
+
+ from ..utils.logger import create_logger
+
+ log = create_logger(__name__)
+
+ try:
+     from langchain_community.chat_models import BedrockChat
+     from langchain_community.llms.bedrock import Bedrock
+ except ImportError:
+     log.warning(
+         "Could not import LangChain's Bedrock Client. If you would like to use Bedrock "
+         "models, please install LangChain's Bedrock Client by running 'pip install "
+         "janus-llm[bedrock]' or poetry install -E bedrock."
+     )
+
+ try:
+     from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+ except ImportError:
+     log.warning(
+         "Could not import LangChain's HuggingFace Pipeline Client. If you would like to "
+         "use HuggingFace models, please install LangChain's HuggingFace Pipeline Client "
+         "by running 'pip install janus-llm[hf-local]' or poetry install -E hf-local."
+     )
+

  load_dotenv()

- MODEL_TYPE_CONSTRUCTORS = {
+ openai_model_reroutes = {
+     "gpt-4o": "gpt-4o-2024-05-13",
+     "gpt-4": "gpt-4-0613",
+     "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
+     "gpt-4-turbo-preview": "gpt-4-0125-preview",
+     "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
+     "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
+ }
+
+ openai_models = [
+     "gpt-4-0613",
+     "gpt-4-1106-preview",
+     "gpt-4-0125-preview",
+     "gpt-4o-2024-05-13",
+     "gpt-3.5-turbo-0125",
+ ]
+ claude_models = [
+     "bedrock-claude-v2",
+     "bedrock-claude-instant-v1",
+     "bedrock-claude-haiku",
+     "bedrock-claude-sonnet",
+ ]
+ llama2_models = [
+     "bedrock-llama2-70b",
+     "bedrock-llama2-70b-chat",
+     "bedrock-llama2-13b",
+     "bedrock-llama2-13b-chat",
+ ]
+ llama3_models = [
+     "bedrock-llama3-8b-instruct",
+     "bedrock-llama3-70b-instruct",
+ ]
+ titan_models = [
+     "bedrock-titan-text-lite",
+     "bedrock-titan-text-express",
+     "bedrock-jurassic-2-mid",
+     "bedrock-jurassic-2-ultra",
+ ]
+ cohere_models = [
+     "bedrock-command-r-plus",
+ ]
+ bedrock_models = [
+     *claude_models,
+     *llama2_models,
+     *llama3_models,
+     *titan_models,
+     *cohere_models,
+ ]
+ all_models = [*openai_models, *bedrock_models]
+
+ MODEL_TYPE_CONSTRUCTORS: dict[str, Callable[[Any], BaseLanguageModel]] = {
      "OpenAI": ChatOpenAI,
      "HuggingFace": HuggingFaceTextGenInference,
-     "HuggingFaceLocal": HuggingFacePipeline.from_model_id,
  }

+ try:
+     MODEL_TYPE_CONSTRUCTORS.update(
+         {
+             "HuggingFaceLocal": HuggingFacePipeline.from_model_id,
+             "Bedrock": Bedrock,
+             "BedrockChat": BedrockChat,
+         }
+     )
+ except NameError:
+     pass
+

- MODEL_TYPES: Dict[str, Any] = {
-     "gpt-4": "OpenAI",
-     "gpt-4-32k": "OpenAI",
-     "gpt-4-1106-preview": "OpenAI",
-     "gpt-3.5-turbo": "OpenAI",
-     "gpt-3.5-turbo-16k": "OpenAI",
-     "mitre-llama": "HuggingFace",
-     "mitre-falcon": "HuggingFace",
-     "mitre-wizard-coder": "HuggingFace",
+ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
+     **{m: ChatGptPromptEngine for m in openai_models},
+     **{m: ClaudePromptEngine for m in claude_models},
+     **{m: Llama2PromptEngine for m in llama2_models},
+     **{m: Llama3PromptEngine for m in llama3_models},
+     **{m: TitanPromptEngine for m in titan_models},
+     **{m: CoherePromptEngine for m in cohere_models},
  }

- _open_ai_defaults: Dict[str, Any] = {
+ _open_ai_defaults: dict[str, str] = {
      "openai_api_key": os.getenv("OPENAI_API_KEY"),
      "openai_organization": os.getenv("OPENAI_ORG_ID"),
  }

- MODEL_DEFAULT_ARGUMENTS: Dict[str, Dict[str, Any]] = {
-     "gpt-4": dict(model_name="gpt-4"),
-     "gpt-4-32k": dict(model_name="gpt-4-32k"),
-     "gpt-4-1106-preview": dict(model_name="gpt-4-1106-preview"),
-     "gpt-3.5-turbo": dict(model_name="gpt-3.5-turbo"),
-     "gpt-3.5-turbo-16k": dict(model_name="gpt-3.5-turbo-16k"),
-     "mitre-llama": dict(
-         inference_server_url="https://llama2-70b.aip.mitre.org",
-         max_new_tokens=4096,
-         top_k=10,
-         top_p=0.95,
-         typical_p=0.95,
-         temperature=0.01,
-         repetition_penalty=1.03,
-         timeout=240,
-     ),
-     "mitre-falcon": dict(
-         inference_server_url="https://falcon-40b.aip.mitre.org",
-         max_new_tokens=4096,
-         top_k=10,
-         top_p=0.95,
-         typical_p=0.95,
-         temperature=0.01,
-         repetition_penalty=1.03,
-         timeout=240,
-     ),
-     "mitre-wizard-coder": dict(
-         inference_server_url="https://wizard-coder-34b.aip.mitre.org",
-         max_new_tokens=1024,
-         top_k=10,
-         top_p=0.95,
-         typical_p=0.95,
-         temperature=0.01,
-         repetition_penalty=1.03,
-         timeout=240,
-     ),
+ model_identifiers = {
+     **{m: m for m in openai_models},
+     "bedrock-claude-v2": "anthropic.claude-v2",
+     "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
+     "bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
+     "bedrock-claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
+     "bedrock-llama2-70b": "meta.llama2-70b-v1",
+     "bedrock-llama2-70b-chat": "meta.llama2-70b-chat-v1",
+     "bedrock-llama2-13b": "meta.llama2-13b-chat-v1",
+     "bedrock-llama2-13b-chat": "meta.llama2-13b-v1",
+     "bedrock-llama3-8b-instruct": "meta.llama3-8b-instruct-v1:0",
+     "bedrock-llama3-70b-instruct": "meta.llama3-70b-instruct-v1:0",
+     "bedrock-titan-text-lite": "amazon.titan-text-lite-v1",
+     "bedrock-titan-text-express": "amazon.titan-text-express-v1",
+     "bedrock-jurassic-2-mid": "ai21.j2-mid-v1",
+     "bedrock-jurassic-2-ultra": "ai21.j2-ultra-v1",
+     "bedrock-command-r-plus": "cohere.command-r-plus-v1:0",
+ }
+
+ MODEL_DEFAULT_ARGUMENTS: dict[str, dict[str, str]] = {
+     k: (dict(model_name=k) if k in openai_models else dict(model_id=v))
+     for k, v in model_identifiers.items()
  }

  DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())

  MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"

- TOKEN_LIMITS: Dict[str, int] = {
-     "gpt-4": 8192,
+ MODEL_TYPES: dict[str, PromptEngine] = {
+     **{m: "OpenAI" for m in openai_models},
+     **{m: "BedrockChat" for m in bedrock_models},
+ }
+
+ TOKEN_LIMITS: dict[str, int] = {
      "gpt-4-32k": 32_768,
+     "gpt-4-0613": 8192,
      "gpt-4-1106-preview": 128_000,
-     "gpt-3.5-turbo": 4096,
-     "gpt-3.5-turbo-16k": 16_384,
-     "mitre-falcon": 32_000,
+     "gpt-4-0125-preview": 128_000,
+     "gpt-4o-2024-05-13": 128_000,
+     "gpt-3.5-turbo-0125": 16_384,
      "text-embedding-ada-002": 8191,
      "gpt4all": 16_384,
- }
-
- COST_PER_MODEL: Dict[str, Dict[str, float]] = {
-     "gpt-4": {"input": 0.03, "output": 0.06},
-     "gpt-4-32k": {"input": 0.6, "output": 0.12},
-     "gpt-4-1106-preview": {"input": 0.01, "output": 0.03},
-     "gpt-3.5-turbo": {"input": 0.0015, "output": 0.002},
-     "gpt-3.5-turbo-16k": {"input": 0.003, "output": 0.004},
-     "mitre-llama": {"input": 0.0, "output": 0.0},
-     "mitre-falcon": {"input": 0.0, "output": 0.0},
-     "mitre-wizard-coder": {"input": 0.0, "output": 0.0},
+     "anthropic.claude-v2": 100_000,
+     "anthropic.claude-instant-v1": 100_000,
+     "anthropic.claude-3-haiku-20240307-v1:0": 248_000,
+     "anthropic.claude-3-sonnet-20240229-v1:0": 248_000,
+     "meta.llama2-70b-v1": 4096,
+     "meta.llama2-70b-chat-v1": 4096,
+     "meta.llama2-13b-chat-v1": 4096,
+     "meta.llama2-13b-v1": 4096,
+     "meta.llama3-8b-instruct-v1:0": 8000,
+     "meta.llama3-70b-instruct-v1:0": 8000,
+     "amazon.titan-text-lite-v1": 4096,
+     "amazon.titan-text-express-v1": 8192,
+     "ai21.j2-mid-v1": 8192,
+     "ai21.j2-ultra-v1": 8192,
+     "cohere.command-r-plus-v1:0": 128_000,
  }


- def load_model(model_name: str) -> Tuple[Any, int, Dict[str, float]]:
+ def load_model(model_name: str) -> tuple[BaseLanguageModel, int, dict[str, float]]:
      if not MODEL_CONFIG_DIR.exists():
          MODEL_CONFIG_DIR.mkdir(parents=True)
      model_config_file = MODEL_CONFIG_DIR / f"{model_name}.json"
      if not model_config_file.exists():
          if model_name not in DEFAULT_MODELS:
-             raise ValueError(f"Error: could not find model {model_name}")
+             if model_name in openai_model_reroutes:
+                 model_name = openai_model_reroutes[model_name]
+             else:
+                 raise ValueError(f"Error: could not find model {model_name}")
          model_config = {
              "model_type": MODEL_TYPES[model_name],
              "model_args": MODEL_DEFAULT_ARGUMENTS[model_name],
-             "token_limit": TOKEN_LIMITS.get(model_name, 4096),
-             "model_cost": COST_PER_MODEL.get(model_name, {"input": 0, "output": 0}),
+             "token_limit": TOKEN_LIMITS.get(model_identifiers[model_name], 4096),
+             "model_cost": COST_PER_1K_TOKENS.get(
+                 model_identifiers[model_name], {"input": 0, "output": 0}
+             ),
          }
          with open(model_config_file, "w") as f:
              json.dump(model_config, f)
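A hypothetical lookup-and-load sketch for the new tables (values read directly from the dictionaries above; loading a Bedrock model additionally assumes the bedrock extra is installed and AWS credentials are configured):

from janus.llm.models_info import MODEL_PROMPT_ENGINES, load_model, model_identifiers

# User-facing aliases resolve to provider model IDs and a matching prompt engine:
print(model_identifiers["bedrock-claude-sonnet"])     # anthropic.claude-3-sonnet-20240229-v1:0
print(MODEL_PROMPT_ENGINES["bedrock-claude-sonnet"])  # ClaudePromptEngine

# The first call caches a config at ~/.janus/llm/<name>.json; limits and costs
# come from TOKEN_LIMITS and COST_PER_1K_TOKENS keyed by the provider ID.
llm, token_limit, model_cost = load_model("bedrock-titan-text-express")
print(token_limit)  # 8192
print(model_cost)   # {'input': 0.0002, 'output': 0.0006}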
janus/metrics/__init__.py ADDED
@@ -0,0 +1,8 @@
+ import glob
+ import os.path
+
+ modules = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
+ __all__ = []
+ for m in modules:
+     if os.path.isfile(m) and not os.path.samefile(m, __file__):
+         __all__.append(os.path.basename(m)[:-3])
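Illustrative effect of this discovery loop (a sketch; the exact list depends on the package contents, and glob order is unspecified):

import janus.metrics

# __all__ holds each sibling module name with ".py" stripped, e.g. 'bleu',
# 'chrf', 'metric', ..., so `from janus.metrics import *` picks them all up.
print(sorted(janus.metrics.__all__))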
janus/metrics/_tests/__init__.py ADDED
File without changes
janus/metrics/_tests/reference.py ADDED
@@ -0,0 +1,2 @@
+ # Hello
+ pass
janus/metrics/_tests/target.py ADDED
@@ -0,0 +1,2 @@
+ # Hello
+ pass
janus/metrics/_tests/test_bleu.py ADDED
@@ -0,0 +1,56 @@
+ import unittest
+
+ from sacrebleu import sentence_bleu
+
+ from ..bleu import bleu
+
+
+ class TestBLEU(unittest.TestCase):
+     def setUp(self):
+         self.target_text = "This is a source text."
+         self.reference_text = "This is a destination text."
+
+     def test_bleu(self):
+         """Test the BLEU score calculation."""
+         function_score = (
+             sentence_bleu(
+                 self.target_text,
+                 [self.reference_text],
+             ).score
+             / 100.0
+         )
+         expected_score = bleu(
+             self.target_text,
+             self.reference_text,
+         )
+         self.assertEqual(function_score, expected_score)
+
+     def test_bleu_with_s_flag(self):
+         """Test the BLEU score calculation with the -S flag."""
+         function_score = (
+             sentence_bleu(
+                 self.target_text,
+                 [self.reference_text],
+             ).score
+             / 100.0
+         )
+         score_with_s_flag = bleu(
+             self.target_text,
+             self.reference_text,
+             use_strings=True,  # Mimics -S
+         )
+         self.assertEqual(function_score, score_with_s_flag)
+
+     def test_bleu_invalid_target_type(self):
+         """Test the BLEU score calculation with invalid source text type."""
+         with self.assertRaises(TypeError):
+             sentence_bleu(123, [self.reference_text])
+
+     def test_bleu_invalid_reference_type(self):
+         """Test the BLEU score calculation with invalid destination text type."""
+         with self.assertRaises(TypeError):
+             sentence_bleu(self.target_text, 123)
+
+
+ if __name__ == "__main__":
+     unittest.main()
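A direct-call sketch of the metric under test, mirroring test_bleu above (per the test, bleu() returns sacrebleu's sentence score divided by 100):

from janus.metrics.bleu import bleu

score = bleu("This is a source text.", "This is a destination text.")
assert 0.0 <= score <= 1.0  # sentence_bleu yields 0-100, scaled down by bleu()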
janus/metrics/_tests/test_chrf.py ADDED
@@ -0,0 +1,67 @@
+ import unittest
+
+ from sacrebleu import sentence_chrf
+
+ from ..chrf import chrf
+
+
+ class TestChrF(unittest.TestCase):
+     def setUp(self):
+         self.target_text = "This is a source text."
+         self.reference_text = "This is a destination text."
+         self.char_order = 6
+         self.word_order = 2
+         self.beta = 2.0
+
+     def test_chrf_custom_params(self):
+         """Test the chrf function with custom parameters."""
+         function_score = chrf(
+             self.target_text,
+             self.reference_text,
+             self.char_order,
+             self.word_order,
+             self.beta,
+         )
+         score = sentence_chrf(
+             hypothesis=self.target_text,
+             references=[self.reference_text],
+             char_order=self.char_order,
+             word_order=self.word_order,
+             beta=self.beta,
+         )
+         expected_score = float(score.score) / 100.0
+         self.assertEqual(function_score, expected_score)
+
+     def test_chrf_with_s_flag(self):
+         """Test the CHRF score calculation with the -S flag."""
+         function_score = sentence_chrf(
+             hypothesis=self.target_text,
+             references=[self.reference_text],
+             char_order=self.char_order,
+             word_order=self.word_order,
+             beta=self.beta,
+         )
+         function_score = float(function_score.score) / 100.0
+         score_with_s_flag = chrf(
+             self.target_text,
+             self.reference_text,
+             self.char_order,
+             self.word_order,
+             self.beta,
+             use_strings=True,  # Mimics -S
+         )
+         self.assertEqual(function_score, score_with_s_flag)
+
+     def test_chrf_invalid_target_type(self):
+         """Test the chrf function with invalid source text type."""
+         with self.assertRaises(TypeError):
+             chrf(123, self.reference_text, self.char_order, self.word_order, self.beta)
+
+     def test_chrf_invalid_reference_type(self):
+         """Test the chrf function with invalid destination text type."""
+         with self.assertRaises(TypeError):
+             chrf(self.target_text, 123, self.char_order, self.word_order, self.beta)
+
+
+ if __name__ == "__main__":
+     unittest.main()
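The same pattern for chrf, mirroring test_chrf_custom_params above (the positional parameters map to sacrebleu's char_order, word_order, and beta):

from janus.metrics.chrf import chrf

score = chrf("This is a source text.", "This is a destination text.", 6, 2, 2.0)
assert 0.0 <= score <= 1.0  # sentence_chrf yields 0-100, scaled down by chrf()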
janus/metrics/_tests/test_file_pairing.py ADDED
@@ -0,0 +1,59 @@
+ # FILEPATH: /Users/mdoyle/projects/janus/janus/metrics/tests/test_file_pairing.py
+
+ import unittest
+ from pathlib import Path
+
+ from ..file_pairing import (
+     FILE_PAIRING_METHODS,
+     pair_by_file,
+     pair_by_line,
+     pair_by_line_comment,
+     register_pairing_method,
+ )
+
+
+ class TestFilePairing(unittest.TestCase):
+     def setUp(self):
+         self.src = "Hello\nWorld"
+         self.cmp = "Hello\nPython"
+         self.state = {
+             "token_limit": 100,
+             "llm": None,
+             "lang": "python",
+             "target_file": self.src,
+             "cmp_file": self.cmp,
+         }
+
+     def test_register_pairing_method(self):
+         @register_pairing_method(name="test")
+         def test_method(src, cmp, state):
+             return [(src, cmp)]
+
+         self.assertIn("test", FILE_PAIRING_METHODS)
+
+     def test_pair_by_file(self):
+         expected = [(self.src, self.cmp)]
+         result = pair_by_file(self.src, self.cmp)
+         self.assertEqual(result, expected)
+
+     def test_pair_by_line(self):
+         expected = [("Hello", "Hello"), ("World", "Python")]
+         result = pair_by_line(self.src, self.cmp)
+         self.assertEqual(result, expected)
+
+     def test_pair_by_line_comment(self):
+         # This test assumes that the source and comparison files have comments on the
+         # same lines
+         # You may need to adjust this test based on your specific use case
+         self.target = Path(__file__).parent / "target.py"
+         self.reference = Path(__file__).parent / "reference.py"
+         kwargs = {
+             "token_limit": 100,
+             "llm": None,
+             "lang": "python",
+             "target_file": self.target,
+             "reference_file": self.reference,
+         }
+         expected = [("# Hello\n", "# Hello\n")]
+         result = pair_by_line_comment(self.src, self.cmp, **kwargs)
+         self.assertEqual(result, expected)
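A registration sketch based on test_register_pairing_method above; the "whole-file" name and the helper are invented for illustration:

from janus.metrics.file_pairing import FILE_PAIRING_METHODS, register_pairing_method

@register_pairing_method(name="whole-file")
def pair_whole_file(src, cmp, state):
    # Treat the two inputs as one aligned pair, like pair_by_file.
    return [(src, cmp)]

assert "whole-file" in FILE_PAIRING_METHODS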