janus-llm 1.0.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +120 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +9 -6
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +134 -70
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/METADATA +23 -10
  70. janus_llm-2.0.0.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.0.dist-info}/entry_points.txt +0 -0
janus/llm/model_callbacks.py ADDED
@@ -0,0 +1,177 @@
+ import threading
+ from contextlib import contextmanager
+ from contextvars import ContextVar
+ from typing import Any, Generator
+
+ from langchain_core.callbacks import BaseCallbackHandler
+ from langchain_core.messages import AIMessage
+ from langchain_core.outputs import ChatGeneration, LLMResult
+ from langchain_core.tracers.context import register_configure_hook
+
+ from janus.utils.logger import create_logger
+
+ log = create_logger(__name__)
+
+
+ # Updated 2024-06-21
+ COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
+     "gpt-3.5-turbo-0125": {"input": 0.0005, "output": 0.0015},
+     "gpt-4-1106-preview": {"input": 0.01, "output": 0.03},
+     "gpt-4-0125-preview": {"input": 0.01, "output": 0.03},
+     "gpt-4-0613": {"input": 0.03, "output": 0.06},
+     "gpt-4o-2024-05-13": {"input": 0.005, "output": 0.015},
+     "anthropic.claude-v2": {"input": 0.008, "output": 0.024},
+     "anthropic.claude-instant-v1": {"input": 0.0008, "output": 0.0024},
+     "anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.00025, "output": 0.00125},
+     "anthropic.claude-3-sonnet-20240229-v1:0": {"input": 0.003, "output": 0.015},
+     "meta.llama2-13b-chat-v1": {"input": 0.00075, "output": 0.001},
+     "meta.llama2-70b-chat-v1": {"input": 0.00195, "output": 0.00256},
+     "meta.llama2-13b-v1": {"input": 0.0, "output": 0.0},
+     "meta.llama2-70b-v1": {"input": 0.00265, "output": 0.0035},
+     "meta.llama3-8b-instruct-v1:0": {"input": 0.0003, "output": 0.0006},
+     "meta.llama3-70b-instruct-v1:0": {"input": 0.00265, "output": 0.0035},
+     "amazon.titan-text-lite-v1": {"input": 0.00015, "output": 0.0002},
+     "amazon.titan-text-express-v1": {"input": 0.0002, "output": 0.0006},
+     "ai21.j2-mid-v1": {"input": 0.0125, "output": 0.0125},
+     "ai21.j2-ultra-v1": {"input": 0.0188, "output": 0.0188},
+     "cohere.command-r-plus-v1:0": {"input": 0.003, "output": 0.015},
+ }
+
+
+ def _get_token_cost(
+     prompt_tokens: int, completion_tokens: int, model_id: str | None
+ ) -> float:
+     """Get the cost of tokens according to model ID"""
+     if model_id not in COST_PER_1K_TOKENS:
+         raise ValueError(
+             f"Unknown model: {model_id}. Please provide a valid model name."
+             f" Known models are: {', '.join(COST_PER_1K_TOKENS.keys())}"
+         )
+     model_cost = COST_PER_1K_TOKENS[model_id]
+     input_cost = (prompt_tokens / 1000.0) * model_cost["input"]
+     output_cost = (completion_tokens / 1000.0) * model_cost["output"]
+     return input_cost + output_cost
+
+
+ class TokenUsageCallbackHandler(BaseCallbackHandler):
+     """Callback Handler that tracks metadata on model cost, retries, etc.
+     Based on https://github.com/langchain-ai/langchain/blob/master/libs
+     /community/langchain_community/callbacks/openai_info.py
+     """
+
+     total_tokens: int = 0
+     prompt_tokens: int = 0
+     completion_tokens: int = 0
+     successful_requests: int = 0
+     total_cost: float = 0.0
+
+     def __init__(self) -> None:
+         super().__init__()
+         self._lock = threading.Lock()
+
+     def __repr__(self) -> str:
+         return (
+             f"Tokens Used: {self.total_tokens}\n"
+             f"\tPrompt Tokens: {self.prompt_tokens}\n"
+             f"\tCompletion Tokens: {self.completion_tokens}\n"
+             f"Successful Requests: {self.successful_requests}\n"
+             f"Total Cost (USD): ${self.total_cost}"
+         )
+
+     @property
+     def always_verbose(self) -> bool:
+         """Whether to call verbose callbacks even if verbose is False."""
+         return True
+
+     def on_chat_model_start(self, *args, **kwargs):
+         pass
+
+     def on_llm_start(
+         self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
+     ) -> None:
+         """Print out the prompts."""
+         pass
+
+     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+         """Print out the token."""
+         pass
+
+     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+         """Collect token usage."""
+         # Check for usage_metadata (langchain-core >= 0.2.2)
+         try:
+             generation = response.generations[0][0]
+         except IndexError:
+             generation = None
+         if isinstance(generation, ChatGeneration):
+             try:
+                 message = generation.message
+                 if isinstance(message, AIMessage):
+                     usage_metadata = message.usage_metadata
+                 else:
+                     usage_metadata = None
+             except AttributeError:
+                 usage_metadata = None
+         else:
+             usage_metadata = None
+         if usage_metadata:
+             token_usage = {"total_tokens": usage_metadata["total_tokens"]}
+             completion_tokens = usage_metadata["output_tokens"]
+             prompt_tokens = usage_metadata["input_tokens"]
+             if response.llm_output is None:
+                 # model name (and therefore cost) is unavailable in
+                 # streaming responses
+                 model_name = ""
+             else:
+                 model_name = response.llm_output.get("model_name", "")
+
+         else:
+             if response.llm_output is None:
+                 return None
+
+             if "token_usage" not in response.llm_output:
+                 with self._lock:
+                     self.successful_requests += 1
+                 return None
+
+             # compute tokens and cost for this request
+             token_usage = response.llm_output["token_usage"]
+             completion_tokens = token_usage.get("completion_tokens", 0)
+             prompt_tokens = token_usage.get("prompt_tokens", 0)
+             model_name = response.llm_output.get("model_name", "")
+
+         total_cost = _get_token_cost(
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             model_id=model_name,
+         )
+
+         # update shared state behind lock
+         with self._lock:
+             self.total_cost += total_cost
+             self.total_tokens += token_usage.get("total_tokens", 0)
+             self.prompt_tokens += prompt_tokens
+             self.completion_tokens += completion_tokens
+             self.successful_requests += 1
+
+     def __copy__(self) -> "TokenUsageCallbackHandler":
+         """Return a copy of the callback handler."""
+         return self
+
+     def __deepcopy__(self, memo: Any) -> "TokenUsageCallbackHandler":
+         """Return a deep copy of the callback handler."""
+         return self
+
+
+ token_usage_callback_var: ContextVar[TokenUsageCallbackHandler | None] = ContextVar(
+     "token_usage_callback_var", default=None
+ )
+ register_configure_hook(token_usage_callback_var, True)
+
+
+ @contextmanager
+ def get_model_callback() -> Generator[TokenUsageCallbackHandler, None, None]:
+     cb = TokenUsageCallbackHandler()
+     token_usage_callback_var.set(cb)
+     yield cb
+     token_usage_callback_var.set(None)
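
For orientation, a minimal sketch of how this new callback is meant to be used: get_model_callback() installs a TokenUsageCallbackHandler in the registered context variable, so token counts and cost accumulate on it for LangChain calls made inside the with block. The llm object below is hypothetical (for example, something returned by load_model); only the callback API comes from this file.

    from janus.llm.model_callbacks import get_model_callback

    with get_model_callback() as cb:
        llm.invoke("Translate this MUMPS routine to Python")  # `llm` is assumed to exist
        print(cb.prompt_tokens, cb.completion_tokens, cb.total_tokens)
        print(f"Estimated cost (USD): {cb.total_cost:.4f}")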
janus/llm/models_info.py CHANGED
@@ -1,115 +1,179 @@
  import json
  import os
  from pathlib import Path
- from typing import Any, Dict, Tuple
+ from typing import Any, Callable

  from dotenv import load_dotenv
- from langchain.chat_models import ChatOpenAI
- from langchain.llms import HuggingFaceTextGenInference
+ from langchain_community.chat_models import BedrockChat
+ from langchain_community.llms import HuggingFaceTextGenInference
+ from langchain_community.llms.bedrock import Bedrock
  from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+ from langchain_core.language_models import BaseLanguageModel
+ from langchain_openai import ChatOpenAI
+
+ from janus.llm.model_callbacks import COST_PER_1K_TOKENS
+ from janus.prompts.prompt import (
+     ChatGptPromptEngine,
+     ClaudePromptEngine,
+     CoherePromptEngine,
+     Llama2PromptEngine,
+     Llama3PromptEngine,
+     PromptEngine,
+     TitanPromptEngine,
+ )

  load_dotenv()

- MODEL_TYPE_CONSTRUCTORS = {
+ openai_model_reroutes = {
+     "gpt-4o": "gpt-4o-2024-05-13",
+     "gpt-4": "gpt-4-0613",
+     "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
+     "gpt-4-turbo-preview": "gpt-4-0125-preview",
+     "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
+     "gpt-3.5-turbo-16k": "gpt-3.5-turbo-0125",
+ }
+
+ openai_models = [
+     "gpt-4-0613",
+     "gpt-4-1106-preview",
+     "gpt-4-0125-preview",
+     "gpt-4o-2024-05-13",
+     "gpt-3.5-turbo-0125",
+ ]
+ claude_models = [
+     "bedrock-claude-v2",
+     "bedrock-claude-instant-v1",
+     "bedrock-claude-haiku",
+     "bedrock-claude-sonnet",
+ ]
+ llama2_models = [
+     "bedrock-llama2-70b",
+     "bedrock-llama2-70b-chat",
+     "bedrock-llama2-13b",
+     "bedrock-llama2-13b-chat",
+ ]
+ llama3_models = [
+     "bedrock-llama3-8b-instruct",
+     "bedrock-llama3-70b-instruct",
+ ]
+ titan_models = [
+     "bedrock-titan-text-lite",
+     "bedrock-titan-text-express",
+     "bedrock-jurassic-2-mid",
+     "bedrock-jurassic-2-ultra",
+ ]
+ cohere_models = [
+     "bedrock-command-r-plus",
+ ]
+ bedrock_models = [
+     *claude_models,
+     *llama2_models,
+     *llama3_models,
+     *titan_models,
+     *cohere_models,
+ ]
+ all_models = [*openai_models, *bedrock_models]
+
+ MODEL_TYPE_CONSTRUCTORS: dict[str, Callable[[Any], BaseLanguageModel]] = {
      "OpenAI": ChatOpenAI,
      "HuggingFace": HuggingFaceTextGenInference,
      "HuggingFaceLocal": HuggingFacePipeline.from_model_id,
+     "Bedrock": Bedrock,
+     "BedrockChat": BedrockChat,
  }


- MODEL_TYPES: Dict[str, Any] = {
-     "gpt-4": "OpenAI",
-     "gpt-4-32k": "OpenAI",
-     "gpt-4-1106-preview": "OpenAI",
-     "gpt-3.5-turbo": "OpenAI",
-     "gpt-3.5-turbo-16k": "OpenAI",
-     "mitre-llama": "HuggingFace",
-     "mitre-falcon": "HuggingFace",
-     "mitre-wizard-coder": "HuggingFace",
+ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
+     **{m: ChatGptPromptEngine for m in openai_models},
+     **{m: ClaudePromptEngine for m in claude_models},
+     **{m: Llama2PromptEngine for m in llama2_models},
+     **{m: Llama3PromptEngine for m in llama3_models},
+     **{m: TitanPromptEngine for m in titan_models},
+     **{m: CoherePromptEngine for m in cohere_models},
  }

- _open_ai_defaults: Dict[str, Any] = {
+ _open_ai_defaults: dict[str, str] = {
      "openai_api_key": os.getenv("OPENAI_API_KEY"),
      "openai_organization": os.getenv("OPENAI_ORG_ID"),
  }

- MODEL_DEFAULT_ARGUMENTS: Dict[str, Dict[str, Any]] = {
-     "gpt-4": dict(model_name="gpt-4"),
-     "gpt-4-32k": dict(model_name="gpt-4-32k"),
-     "gpt-4-1106-preview": dict(model_name="gpt-4-1106-preview"),
-     "gpt-3.5-turbo": dict(model_name="gpt-3.5-turbo"),
-     "gpt-3.5-turbo-16k": dict(model_name="gpt-3.5-turbo-16k"),
-     "mitre-llama": dict(
-         inference_server_url="https://llama2-70b.aip.mitre.org",
-         max_new_tokens=4096,
-         top_k=10,
-         top_p=0.95,
-         typical_p=0.95,
-         temperature=0.01,
-         repetition_penalty=1.03,
-         timeout=240,
-     ),
-     "mitre-falcon": dict(
-         inference_server_url="https://falcon-40b.aip.mitre.org",
-         max_new_tokens=4096,
-         top_k=10,
-         top_p=0.95,
-         typical_p=0.95,
-         temperature=0.01,
-         repetition_penalty=1.03,
-         timeout=240,
-     ),
-     "mitre-wizard-coder": dict(
-         inference_server_url="https://wizard-coder-34b.aip.mitre.org",
-         max_new_tokens=1024,
-         top_k=10,
-         top_p=0.95,
-         typical_p=0.95,
-         temperature=0.01,
-         repetition_penalty=1.03,
-         timeout=240,
-     ),
+ model_identifiers = {
+     **{m: m for m in openai_models},
+     "bedrock-claude-v2": "anthropic.claude-v2",
+     "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
+     "bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
+     "bedrock-claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
+     "bedrock-llama2-70b": "meta.llama2-70b-v1",
+     "bedrock-llama2-70b-chat": "meta.llama2-70b-chat-v1",
+     "bedrock-llama2-13b": "meta.llama2-13b-chat-v1",
+     "bedrock-llama2-13b-chat": "meta.llama2-13b-v1",
+     "bedrock-llama3-8b-instruct": "meta.llama3-8b-instruct-v1:0",
+     "bedrock-llama3-70b-instruct": "meta.llama3-70b-instruct-v1:0",
+     "bedrock-titan-text-lite": "amazon.titan-text-lite-v1",
+     "bedrock-titan-text-express": "amazon.titan-text-express-v1",
+     "bedrock-jurassic-2-mid": "ai21.j2-mid-v1",
+     "bedrock-jurassic-2-ultra": "ai21.j2-ultra-v1",
+     "bedrock-command-r-plus": "cohere.command-r-plus-v1:0",
+ }
+
+ MODEL_DEFAULT_ARGUMENTS: dict[str, dict[str, str]] = {
+     k: (dict(model_name=k) if k in openai_models else dict(model_id=v))
+     for k, v in model_identifiers.items()
  }

  DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())

  MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"

- TOKEN_LIMITS: Dict[str, int] = {
-     "gpt-4": 8192,
+ MODEL_TYPES: dict[str, PromptEngine] = {
+     **{model_identifiers[m]: "OpenAI" for m in openai_models},
+     **{model_identifiers[m]: "BedrockChat" for m in bedrock_models},
+ }
+
+ TOKEN_LIMITS: dict[str, int] = {
      "gpt-4-32k": 32_768,
+     "gpt-4-0613": 8192,
      "gpt-4-1106-preview": 128_000,
-     "gpt-3.5-turbo": 4096,
-     "gpt-3.5-turbo-16k": 16_384,
-     "mitre-falcon": 32_000,
+     "gpt-4-0125-preview": 128_000,
+     "gpt-4o-2024-05-13": 128_000,
+     "gpt-3.5-turbo-0125": 16_384,
      "text-embedding-ada-002": 8191,
      "gpt4all": 16_384,
- }
-
- COST_PER_MODEL: Dict[str, Dict[str, float]] = {
-     "gpt-4": {"input": 0.03, "output": 0.06},
-     "gpt-4-32k": {"input": 0.6, "output": 0.12},
-     "gpt-4-1106-preview": {"input": 0.01, "output": 0.03},
-     "gpt-3.5-turbo": {"input": 0.0015, "output": 0.002},
-     "gpt-3.5-turbo-16k": {"input": 0.003, "output": 0.004},
-     "mitre-llama": {"input": 0.0, "output": 0.0},
-     "mitre-falcon": {"input": 0.0, "output": 0.0},
-     "mitre-wizard-coder": {"input": 0.0, "output": 0.0},
+     "anthropic.claude-v2": 100_000,
+     "anthropic.claude-instant-v1": 100_000,
+     "anthropic.claude-3-haiku-20240307-v1:0": 248_000,
+     "anthropic.claude-3-sonnet-20240229-v1:0": 248_000,
+     "meta.llama2-70b-v1": 4096,
+     "meta.llama2-70b-chat-v1": 4096,
+     "meta.llama2-13b-chat-v1": 4096,
+     "meta.llama2-13b-v1": 4096,
+     "meta.llama3-8b-instruct-v1:0": 8000,
+     "meta.llama3-70b-instruct-v1:0": 8000,
+     "amazon.titan-text-lite-v1": 4096,
+     "amazon.titan-text-express-v1": 8192,
+     "ai21.j2-mid-v1": 8192,
+     "ai21.j2-ultra-v1": 8192,
+     "cohere.command-r-plus-v1:0": 128_000,
  }


- def load_model(model_name: str) -> Tuple[Any, int, Dict[str, float]]:
+ def load_model(model_name: str) -> tuple[BaseLanguageModel, int, dict[str, float]]:
      if not MODEL_CONFIG_DIR.exists():
          MODEL_CONFIG_DIR.mkdir(parents=True)
      model_config_file = MODEL_CONFIG_DIR / f"{model_name}.json"
      if not model_config_file.exists():
          if model_name not in DEFAULT_MODELS:
-             raise ValueError(f"Error: could not find model {model_name}")
+             if model_name in openai_model_reroutes:
+                 model_name = openai_model_reroutes[model_name]
+             else:
+                 raise ValueError(f"Error: could not find model {model_name}")
          model_config = {
              "model_type": MODEL_TYPES[model_name],
              "model_args": MODEL_DEFAULT_ARGUMENTS[model_name],
-             "token_limit": TOKEN_LIMITS.get(model_name, 4096),
-             "model_cost": COST_PER_MODEL.get(model_name, {"input": 0, "output": 0}),
+             "token_limit": TOKEN_LIMITS.get(model_identifiers[model_name], 4096),
+             "model_cost": COST_PER_1K_TOKENS.get(
+                 model_identifiers[model_name], {"input": 0, "output": 0}
+             ),
          }
          with open(model_config_file, "w") as f:
              json.dump(model_config, f)
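
A usage sketch follows, assuming (per the annotated return type, since the function's return statement falls outside this hunk) that load_model returns the constructed model, its token limit, and its per-1K-token cost; the Bedrock-backed names would additionally need AWS credentials at construction time.

    from janus.llm.models_info import load_model

    llm, token_limit, model_cost = load_model("bedrock-claude-sonnet")
    # token_limit -> 248_000, via model_identifiers and TOKEN_LIMITS
    # model_cost  -> {"input": 0.003, "output": 0.015}, via COST_PER_1K_TOKENS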
janus/metrics/__init__.py ADDED
@@ -0,0 +1,8 @@
+ import glob
+ import os.path
+
+ modules = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
+ __all__ = []
+ for m in modules:
+     if os.path.isfile(m) and not os.path.samefile(m, __file__):
+         __all__.append(os.path.basename(m)[:-3])
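
This __init__.py auto-populates __all__ with the name of every sibling module, so "from janus.metrics import *" picks up each metric module without listing them by hand. A standalone illustration of the same idiom (the directory and module names here are hypothetical):

    import glob
    import os.path

    pkg_dir = "/tmp/metrics_demo"  # stand-in for os.path.dirname(__file__)
    modules = glob.glob(os.path.join(pkg_dir, "*.py"))
    public = [os.path.basename(m)[:-3] for m in modules if os.path.isfile(m)]
    # With bleu.py, chrf.py, and metric.py present, public would contain
    # ["bleu", "chrf", "metric"] (glob order is not guaranteed).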
File without changes
janus/metrics/_tests/reference.py ADDED
@@ -0,0 +1,2 @@
+ # Hello
+ pass
janus/metrics/_tests/target.py ADDED
@@ -0,0 +1,2 @@
+ # Hello
+ pass
janus/metrics/_tests/test_bleu.py ADDED
@@ -0,0 +1,56 @@
+ import unittest
+
+ from sacrebleu import sentence_bleu
+
+ from ..bleu import bleu
+
+
+ class TestBLEU(unittest.TestCase):
+     def setUp(self):
+         self.target_text = "This is a source text."
+         self.reference_text = "This is a destination text."
+
+     def test_bleu(self):
+         """Test the BLEU score calculation."""
+         function_score = (
+             sentence_bleu(
+                 self.target_text,
+                 [self.reference_text],
+             ).score
+             / 100.0
+         )
+         expected_score = bleu(
+             self.target_text,
+             self.reference_text,
+         )
+         self.assertEqual(function_score, expected_score)
+
+     def test_bleu_with_s_flag(self):
+         """Test the BLEU score calculation with the -S flag."""
+         function_score = (
+             sentence_bleu(
+                 self.target_text,
+                 [self.reference_text],
+             ).score
+             / 100.0
+         )
+         score_with_s_flag = bleu(
+             self.target_text,
+             self.reference_text,
+             use_strings=True,  # Mimics -S
+         )
+         self.assertEqual(function_score, score_with_s_flag)
+
+     def test_bleu_invalid_target_type(self):
+         """Test the BLEU score calculation with invalid source text type."""
+         with self.assertRaises(TypeError):
+             sentence_bleu(123, [self.reference_text])
+
+     def test_bleu_invalid_reference_type(self):
+         """Test the BLEU score calculation with invalid destination text type."""
+         with self.assertRaises(TypeError):
+             sentence_bleu(self.target_text, 123)
+
+
+ if __name__ == "__main__":
+     unittest.main()
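
The packaged janus/metrics/bleu.py is not shown in this excerpt. For reference, a wrapper consistent with what these tests assert (sacrebleu's 0-100 sentence-level score rescaled to 0-1) might look like the sketch below; the real bleu() also accepts at least use_strings, per the -S test above.

    from sacrebleu import sentence_bleu

    def bleu_sketch(target: str, reference: str) -> float:
        # Sentence-level BLEU, rescaled from sacrebleu's 0-100 range to 0-1.
        return sentence_bleu(target, [reference]).score / 100.0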
janus/metrics/_tests/test_chrf.py ADDED
@@ -0,0 +1,67 @@
+ import unittest
+
+ from sacrebleu import sentence_chrf
+
+ from ..chrf import chrf
+
+
+ class TestChrF(unittest.TestCase):
+     def setUp(self):
+         self.target_text = "This is a source text."
+         self.reference_text = "This is a destination text."
+         self.char_order = 6
+         self.word_order = 2
+         self.beta = 2.0
+
+     def test_chrf_custom_params(self):
+         """Test the chrf function with custom parameters."""
+         function_score = chrf(
+             self.target_text,
+             self.reference_text,
+             self.char_order,
+             self.word_order,
+             self.beta,
+         )
+         score = sentence_chrf(
+             hypothesis=self.target_text,
+             references=[self.reference_text],
+             char_order=self.char_order,
+             word_order=self.word_order,
+             beta=self.beta,
+         )
+         expected_score = float(score.score) / 100.0
+         self.assertEqual(function_score, expected_score)
+
+     def test_chrf_with_s_flag(self):
+         """Test the CHRF score calculation with the -S flag."""
+         function_score = sentence_chrf(
+             hypothesis=self.target_text,
+             references=[self.reference_text],
+             char_order=self.char_order,
+             word_order=self.word_order,
+             beta=self.beta,
+         )
+         function_score = float(function_score.score) / 100.0
+         score_with_s_flag = chrf(
+             self.target_text,
+             self.reference_text,
+             self.char_order,
+             self.word_order,
+             self.beta,
+             use_strings=True,  # Mimics -S
+         )
+         self.assertEqual(function_score, score_with_s_flag)
+
+     def test_chrf_invalid_target_type(self):
+         """Test the chrf function with invalid source text type."""
+         with self.assertRaises(TypeError):
+             chrf(123, self.reference_text, self.char_order, self.word_order, self.beta)
+
+     def test_chrf_invalid_reference_type(self):
+         """Test the chrf function with invalid destination text type."""
+         with self.assertRaises(TypeError):
+             chrf(self.target_text, 123, self.char_order, self.word_order, self.beta)
+
+
+ if __name__ == "__main__":
+     unittest.main()
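
Likewise, janus/metrics/chrf.py itself falls outside this excerpt; a minimal sketch consistent with the expected values computed in these tests:

    from sacrebleu import sentence_chrf

    def chrf_sketch(target, reference, char_order=6, word_order=2, beta=2.0) -> float:
        # Sentence-level chrF, rescaled from sacrebleu's 0-100 range to 0-1.
        result = sentence_chrf(
            hypothesis=target,
            references=[reference],
            char_order=char_order,
            word_order=word_order,
            beta=beta,
        )
        return float(result.score) / 100.0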
janus/metrics/_tests/test_file_pairing.py ADDED
@@ -0,0 +1,59 @@
+ # FILEPATH: /Users/mdoyle/projects/janus/janus/metrics/tests/test_file_pairing.py
+
+ import unittest
+ from pathlib import Path
+
+ from ..file_pairing import (
+     FILE_PAIRING_METHODS,
+     pair_by_file,
+     pair_by_line,
+     pair_by_line_comment,
+     register_pairing_method,
+ )
+
+
+ class TestFilePairing(unittest.TestCase):
+     def setUp(self):
+         self.src = "Hello\nWorld"
+         self.cmp = "Hello\nPython"
+         self.state = {
+             "token_limit": 100,
+             "llm": None,
+             "lang": "python",
+             "target_file": self.src,
+             "cmp_file": self.cmp,
+         }
+
+     def test_register_pairing_method(self):
+         @register_pairing_method(name="test")
+         def test_method(src, cmp, state):
+             return [(src, cmp)]
+
+         self.assertIn("test", FILE_PAIRING_METHODS)
+
+     def test_pair_by_file(self):
+         expected = [(self.src, self.cmp)]
+         result = pair_by_file(self.src, self.cmp)
+         self.assertEqual(result, expected)
+
+     def test_pair_by_line(self):
+         expected = [("Hello", "Hello"), ("World", "Python")]
+         result = pair_by_line(self.src, self.cmp)
+         self.assertEqual(result, expected)
+
+     def test_pair_by_line_comment(self):
+         # This test assumes that the source and comparison files have comments on the
+         # same lines
+         # You may need to adjust this test based on your specific use case
+         self.target = Path(__file__).parent / "target.py"
+         self.reference = Path(__file__).parent / "reference.py"
+         kwargs = {
+             "token_limit": 100,
+             "llm": None,
+             "lang": "python",
+             "target_file": self.target,
+             "reference_file": self.reference,
+         }
+         expected = [("# Hello\n", "# Hello\n")]
+         result = pair_by_line_comment(self.src, self.cmp, **kwargs)
+         self.assertEqual(result, expected)
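
The pairing helpers in janus/metrics/file_pairing.py are also outside this excerpt. Judging only from the assertions above, pair_by_file pairs the two inputs whole and pair_by_line zips them line by line, roughly as in this sketch (names are suffixed _sketch to make clear this is not the packaged implementation):

    def pair_by_file_sketch(src: str, cmp: str) -> list[tuple[str, str]]:
        # Treat each whole input as a single pair.
        return [(src, cmp)]

    def pair_by_line_sketch(src: str, cmp: str) -> list[tuple[str, str]]:
        # Pair corresponding lines of the two inputs.
        return list(zip(src.splitlines(), cmp.splitlines()))

    assert pair_by_line_sketch("Hello\nWorld", "Hello\nPython") == [
        ("Hello", "Hello"),
        ("World", "Python"),
    ]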