janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (74)
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +130 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +19 -14
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +165 -72
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
  70. janus_llm-2.0.1.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,72 @@
1
+ import json
2
+ import re
3
+
4
+ from langchain.output_parsers.json import parse_json_markdown
5
+ from langchain.schema.output_parser import BaseOutputParser
6
+ from langchain_core.exceptions import OutputParserException
7
+ from langchain_core.messages import AIMessage
8
+
9
+ from ..language.block import CodeBlock
10
+ from ..utils.logger import create_logger
11
+ from .code_parser import JanusParser
12
+
13
+ log = create_logger(__name__)
14
+
15
+
16
class RequirementsParser(BaseOutputParser[str], JanusParser):
    """Output parser for LLM responses containing a JSON requirements document.

    ``parse`` checks that a single response decodes to a JSON object and
    returns it re-serialized as a JSON string. ``parse_combined_output``
    collects the ``"requirements"`` values from a response that covers
    several combined inputs.
    """

    # Name of the code block the parsed output belongs to (see ``set_reference``).
    block_name: str = ""

    def __init__(self):
        # This parser passes no expected keys to the JanusParser base.
        super().__init__(expected_keys=[])

    def set_reference(self, block: CodeBlock):
        """Remember which code block the upcoming LLM output refers to.

        Arguments:
            block: The code block whose ``name`` is stored in ``block_name``.
        """
        self.block_name = block.name

    def parse(self, text: str) -> str:
        """Parse the LLM output into a JSON string.

        Arguments:
            text: The output text from the LLM, or an ``AIMessage`` whose
                ``content`` holds it.

        Returns:
            The decoded JSON object, re-serialized with ``json.dumps``.

        Raises:
            OutputParserException: If the text is not valid JSON, or if it
                does not decode to a JSON object (dictionary).
        """
        if isinstance(text, AIMessage):
            text = text.content
        text = text.strip()
        # Drop a leading markdown fence ("```", optionally followed by "json").
        # Slicing is deliberate: ``str.lstrip("```json")`` would strip any run
        # of the characters `, j, s, o, n rather than the literal prefix.
        if text.startswith("```"):
            text = text[3:]
            if text.startswith("json"):
                text = text[4:]
        # Drop a trailing closing fence.
        text = text.rstrip("`")
        try:
            obj = parse_json_markdown(text)
        except json.JSONDecodeError as e:
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise OutputParserException(
                f"Got invalid JSON object. Error: {e}"
            ) from e

        if not isinstance(obj, dict):
            raise OutputParserException(
                f"Got invalid return object. Expected a dictionary, but got {type(obj)}"
            )
        return json.dumps(obj)

    def parse_combined_output(self, text: str):
        """Parse the output text from the LLM when multiple inputs are combined.

        Every top-level JSON object embedded in ``text`` is decoded and its
        ``"requirements"`` value is collected. Objects may span several lines
        and may contain nested objects. A ``{`` that does not begin valid
        JSON is skipped.

        Arguments:
            text: The output text from the LLM.

        Returns:
            A list of the ``"requirements"`` values, in order of appearance.

        Raises:
            KeyError: If a decoded JSON object has no ``"requirements"`` key.
        """
        decoder = json.JSONDecoder()
        output_list = []
        pos = text.find("{")
        while pos != -1:
            try:
                # raw_decode consumes exactly one complete JSON value, so
                # nested braces and newlines inside it are handled correctly.
                json_dict, end = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                pos = text.find("{", pos + 1)
                continue
            output_list.append(json_dict["requirements"])
            pos = text.find("{", end)
        return output_list

    def get_format_instructions(self) -> str:
        """Get the format instructions for the parser.

        Returns:
            The format instructions for the LLM.
        """
        return (
            "Output must contain an ieee style requirements specification "
            "all in a json-formatted string, including the following field: "
            '"requirements".'
        )

    @property
    def _type(self) -> str:
        # ``__name__`` is the class's name string; ``.name`` is not the
        # class-name attribute.
        return self.__class__.__name__
janus/prompts/prompt.py CHANGED
@@ -1,15 +1,14 @@
1
1
  import json
2
+ from abc import ABC, abstractmethod
2
3
  from pathlib import Path
3
- from typing import List
4
4
 
5
+ from langchain import PromptTemplate
5
6
  from langchain.prompts import ChatPromptTemplate
6
7
  from langchain.prompts.chat import (
7
8
  HumanMessagePromptTemplate,
8
9
  SystemMessagePromptTemplate,
9
10
  )
10
- from langchain.schema.messages import BaseMessage
11
11
 
12
- from ..language.block import CodeBlock
13
12
  from ..utils.enums import LANGUAGES
14
13
  from ..utils.logger import create_logger
15
14
 
@@ -18,12 +17,13 @@ log = create_logger(__name__)
18
17
 
19
18
  # Prompt names (self.template_map keys) that should output text,
20
19
  # regardless of the `output-lang` argument.
21
- TEXT_OUTPUT = ["document", "requirements"]
20
+ TEXT_OUTPUT = []
21
+
22
22
  # Prompt names (self.template_map keys) that should output the
23
23
  # same language as the input, regardless of the `output-lang` argument.
24
24
  SAME_OUTPUT = ["document_inline"]
25
25
 
26
- JSON_OUTPUT = ["evaluate"]
26
+ JSON_OUTPUT = ["evaluate", "document", "document_madlibs", "requirements"]
27
27
 
28
28
  # Directory containing Janus prompt template directories and files
29
29
  JANUS_PROMPT_TEMPLATES_DIR = Path(__file__).parent / "templates"
@@ -34,7 +34,7 @@ HUMAN_PROMPT_TEMPLATE_FILENAME = "human.txt"
34
34
  PROMPT_VARIABLES_FILENAME = "variables.json"
35
35
 
36
36
 
37
- class PromptEngine:
37
+ class PromptEngine(ABC):
38
38
  """A class defining prompting schemes for the LLM."""
39
39
 
40
40
  def __init__(
@@ -59,22 +59,14 @@ class PromptEngine:
59
59
  template_path = self.get_prompt_template_path(prompt_template)
60
60
  self._template_path = template_path
61
61
  self._template_name = prompt_template
62
- system_prompt_path = SystemMessagePromptTemplate.from_template(
63
- (template_path / SYSTEM_PROMPT_TEMPLATE_FILENAME).read_text()
64
- )
65
- human_prompt_path = HumanMessagePromptTemplate.from_template(
66
- (template_path / HUMAN_PROMPT_TEMPLATE_FILENAME).read_text()
67
- )
68
- self.prompt = ChatPromptTemplate.from_messages(
69
- [system_prompt_path, human_prompt_path]
70
- )
62
+ self.prompt = self.load_prompt_template(template_path)
71
63
 
72
64
  # Define variables to be passed in to the prompt formatter
73
65
  source_language = source_language.lower()
74
66
  target_language = target_language.lower()
75
67
  self.variables = dict(
76
- SOURCE_LANGUAGE=source_language.lower(),
77
- TARGET_LANGUAGE=target_language.lower(),
68
+ SOURCE_LANGUAGE=source_language,
69
+ TARGET_LANGUAGE=target_language,
78
70
  TARGET_LANGUAGE_VERSION=str(target_version),
79
71
  FILE_SUFFIX=LANGUAGES[source_language]["suffix"],
80
72
  SOURCE_CODE_EXAMPLE=LANGUAGES[source_language]["example"],
@@ -83,20 +75,11 @@ class PromptEngine:
83
75
  variables_path = template_path / PROMPT_VARIABLES_FILENAME
84
76
  if variables_path.exists():
85
77
  self.variables.update(json.loads(variables_path.read_text()))
78
+ self.prompt = self.prompt.partial(**self.variables)
86
79
 
87
- def create(self, code: CodeBlock) -> List[BaseMessage]:
88
- """Convert a code block to a Chat GPT prompt.
89
-
90
- Arguments:
91
- code: The code block to convert.
92
-
93
- Returns:
94
- The converted prompt as a list of messages.
95
- """
96
- return self.prompt.format_prompt(
97
- SOURCE_CODE=code.text,
98
- **self.variables,
99
- ).to_messages()
80
+ @abstractmethod
81
+ def load_prompt_template(self, template_path: Path) -> ChatPromptTemplate:
82
+ pass
100
83
 
101
84
  @staticmethod
102
85
  def get_prompt_template_path(template_name: str) -> Path:
@@ -146,3 +129,93 @@ class PromptEngine:
146
129
  f"Specified prompt template directory {template_path} is "
147
130
  f"missing a {HUMAN_PROMPT_TEMPLATE_FILENAME}"
148
131
  )
132
+
133
+
134
class ChatGptPromptEngine(PromptEngine):
    """Prompt engine that builds a two-message (system, then human) chat prompt."""

    def load_prompt_template(self, template_path: Path) -> ChatPromptTemplate:
        """Build a chat prompt from the system and human template files.

        Arguments:
            template_path: Directory holding the prompt template files.

        Returns:
            A ``ChatPromptTemplate`` with a system message followed by a
            human message.
        """
        messages = [
            SystemMessagePromptTemplate.from_template(
                (template_path / SYSTEM_PROMPT_TEMPLATE_FILENAME).read_text()
            ),
            HumanMessagePromptTemplate.from_template(
                (template_path / HUMAN_PROMPT_TEMPLATE_FILENAME).read_text()
            ),
        ]
        return ChatPromptTemplate.from_messages(messages)
144
+
145
+
146
class ClaudePromptEngine(PromptEngine):
    """Prompt engine producing a single ``Human:`` / ``Assistant:`` text prompt.

    Only the human template file is used; no system prompt is included.
    """

    def load_prompt_template(self, template_path: Path) -> PromptTemplate:
        """Build a plain-text prompt from the human template file.

        Arguments:
            template_path: Directory holding the prompt template files.

        Returns:
            A ``PromptTemplate`` (not a chat template) that wraps the human
            prompt in a ``Human:`` turn and leaves an open ``Assistant:`` turn.
        """
        human_prompt = (template_path / HUMAN_PROMPT_TEMPLATE_FILENAME).read_text()
        return PromptTemplate.from_template(f"Human: {human_prompt}\n\nAssistant: ")
151
+
152
+
153
class TitanPromptEngine(PromptEngine):
    """Prompt engine producing a single ``User:`` / ``Assistant:`` text prompt.

    Only the human template file is used; no system prompt is included.
    """

    def load_prompt_template(self, template_path: Path) -> PromptTemplate:
        """Build a plain-text prompt from the human template file.

        Arguments:
            template_path: Directory holding the prompt template files.

        Returns:
            A ``PromptTemplate`` (not a chat template) that wraps the human
            prompt in a ``User:`` turn and leaves an open ``Assistant:`` turn.
        """
        human_prompt = (template_path / HUMAN_PROMPT_TEMPLATE_FILENAME).read_text()
        return PromptTemplate.from_template(f"User: {human_prompt}\n\nAssistant: ")
158
+
159
+
160
class Llama2PromptEngine(PromptEngine):
    """Prompt engine using the Llama 2 ``[INST]`` / ``<<SYS>>`` text format."""

    def load_prompt_template(self, template_path: Path) -> PromptTemplate:
        """Build a Llama 2 formatted prompt from the system and human templates.

        Arguments:
            template_path: Directory holding the prompt template files.

        Returns:
            A ``PromptTemplate`` (not a chat template). The system prompt sits
            inside a ``<<SYS>>`` block and the human prompt follows it within
            one ``[INST] ... [/INST]`` instruction.
        """
        system_prompt = (template_path / SYSTEM_PROMPT_TEMPLATE_FILENAME).read_text()
        human_prompt = (template_path / HUMAN_PROMPT_TEMPLATE_FILENAME).read_text()
        return PromptTemplate.from_template(
            f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{human_prompt} [/INST]"
        )
171
+
172
+
173
class Llama3PromptEngine(PromptEngine):
    """Prompt engine using the Llama 3 special-token chat format."""

    # see https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
    # /#special-tokens-used-with-meta-llama-3
    def load_prompt_template(self, template_path: Path) -> PromptTemplate:
        """Build a Llama 3 formatted prompt from the system and human templates.

        Arguments:
            template_path: Directory holding the prompt template files.

        Returns:
            A ``PromptTemplate`` (not a chat template) containing a system turn
            and a user turn. It ends with an open assistant header, so the
            model's completion becomes the assistant turn.
        """
        system_prompt = (template_path / SYSTEM_PROMPT_TEMPLATE_FILENAME).read_text()
        human_prompt = (template_path / HUMAN_PROMPT_TEMPLATE_FILENAME).read_text()
        return PromptTemplate.from_template(
            f"<|begin_of_text|>"
            f"<|start_header_id|>"
            f"system"
            f"<|end_header_id|>"
            f"\n\n{system_prompt}"
            f"<|eot_id|>"
            f"<|start_header_id|>"
            f"user"
            f"<|end_header_id|>"
            f"\n\n{human_prompt}"
            f"<|eot_id|>"
            # Open the assistant turn for the model to complete.
            f"<|start_header_id|>"
            f"assistant"
            f"<|end_header_id|>"
            f"\n\n"
        )
200
+
201
+
202
class CoherePromptEngine(PromptEngine):
    """Prompt engine using the Cohere Command R special-token format."""

    # see https://docs.cohere.com/docs/prompting-command-r
    def load_prompt_template(self, template_path: Path) -> PromptTemplate:
        """Build a Command R formatted prompt from the system and human templates.

        Arguments:
            template_path: Directory holding the prompt template files.

        Returns:
            A ``PromptTemplate`` (not a chat template) containing a system turn
            and a user turn. It ends with an opened chatbot turn, so the
            model's completion becomes the reply, per the Command R prompting
            guide linked above.
        """
        system_prompt = (template_path / SYSTEM_PROMPT_TEMPLATE_FILENAME).read_text()
        human_prompt = (template_path / HUMAN_PROMPT_TEMPLATE_FILENAME).read_text()
        return PromptTemplate.from_template(
            f"<BOS_TOKEN>"
            f"<|START_OF_TURN_TOKEN|>"
            f"<|SYSTEM_TOKEN|>"
            f"{system_prompt}"
            f"<|END_OF_TURN_TOKEN|>"
            f"<|START_OF_TURN_TOKEN|>"
            f"<|USER_TOKEN|>"
            f"{human_prompt}"
            f"<|END_OF_TURN_TOKEN|>"
            # Open the chatbot turn (previously missing) for the model to
            # complete, mirroring the open assistant turn in the Llama 3 and
            # Claude engines.
            f"<|START_OF_TURN_TOKEN|>"
            f"<|CHATBOT_TOKEN|>"
        )