janus-llm 4.2.0__py3-none-any.whl → 4.3.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- janus/__init__.py +1 -1
- janus/__main__.py +1 -1
- janus/_tests/evaluator_tests/EvalReadMe.md +85 -0
- janus/_tests/evaluator_tests/incose_tests/incose_large_test.json +39 -0
- janus/_tests/evaluator_tests/incose_tests/incose_small_test.json +17 -0
- janus/_tests/evaluator_tests/inline_comment_tests/mumps_inline_comment_test.m +71 -0
- janus/_tests/test_cli.py +3 -2
- janus/cli/aggregate.py +135 -0
- janus/cli/cli.py +111 -0
- janus/cli/constants.py +43 -0
- janus/cli/database.py +289 -0
- janus/cli/diagram.py +178 -0
- janus/cli/document.py +174 -0
- janus/cli/embedding.py +122 -0
- janus/cli/llm.py +187 -0
- janus/cli/partition.py +125 -0
- janus/cli/self_eval.py +149 -0
- janus/cli/translate.py +183 -0
- janus/converter/__init__.py +1 -1
- janus/converter/_tests/test_translate.py +2 -0
- janus/converter/converter.py +129 -92
- janus/converter/document.py +21 -14
- janus/converter/evaluate.py +237 -4
- janus/converter/translate.py +3 -3
- janus/embedding/collections.py +1 -1
- janus/language/alc/_tests/alc.asm +3779 -0
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +9 -4
- janus/language/binary/_tests/hello.bin +0 -0
- janus/language/block.py +47 -12
- janus/language/file.py +1 -1
- janus/language/mumps/_tests/mumps.m +235 -0
- janus/language/splitter.py +31 -23
- janus/language/treesitter/_tests/languages/fortran.f90 +416 -0
- janus/language/treesitter/_tests/languages/ibmhlasm.asm +16 -0
- janus/language/treesitter/_tests/languages/matlab.m +225 -0
- janus/language/treesitter/treesitter.py +9 -1
- janus/llm/models_info.py +26 -13
- janus/metrics/_tests/asm_test_file.asm +10 -0
- janus/metrics/_tests/mumps_test_file.m +6 -0
- janus/metrics/_tests/test_treesitter_metrics.py +1 -1
- janus/metrics/prompts/clarity.txt +8 -0
- janus/metrics/prompts/completeness.txt +16 -0
- janus/metrics/prompts/faithfulness.txt +10 -0
- janus/metrics/prompts/hallucination.txt +16 -0
- janus/metrics/prompts/quality.txt +8 -0
- janus/metrics/prompts/readability.txt +16 -0
- janus/metrics/prompts/usefulness.txt +16 -0
- janus/parsers/code_parser.py +4 -4
- janus/parsers/doc_parser.py +12 -9
- janus/parsers/eval_parsers/incose_parser.py +134 -0
- janus/parsers/eval_parsers/inline_comment_parser.py +112 -0
- janus/parsers/parser.py +7 -0
- janus/parsers/partition_parser.py +47 -13
- janus/parsers/reqs_parser.py +8 -5
- janus/parsers/uml.py +5 -4
- janus/prompts/prompt.py +2 -2
- janus/prompts/templates/README.md +30 -0
- janus/prompts/templates/basic_aggregation/human.txt +6 -0
- janus/prompts/templates/basic_aggregation/system.txt +1 -0
- janus/prompts/templates/basic_refinement/human.txt +14 -0
- janus/prompts/templates/basic_refinement/system.txt +1 -0
- janus/prompts/templates/diagram/human.txt +9 -0
- janus/prompts/templates/diagram/system.txt +1 -0
- janus/prompts/templates/diagram_with_documentation/human.txt +15 -0
- janus/prompts/templates/diagram_with_documentation/system.txt +1 -0
- janus/prompts/templates/document/human.txt +10 -0
- janus/prompts/templates/document/system.txt +1 -0
- janus/prompts/templates/document_cloze/human.txt +11 -0
- janus/prompts/templates/document_cloze/system.txt +1 -0
- janus/prompts/templates/document_cloze/variables.json +4 -0
- janus/prompts/templates/document_cloze/variables_asm.json +4 -0
- janus/prompts/templates/document_inline/human.txt +13 -0
- janus/prompts/templates/eval_prompts/incose/human.txt +32 -0
- janus/prompts/templates/eval_prompts/incose/system.txt +1 -0
- janus/prompts/templates/eval_prompts/incose/variables.json +3 -0
- janus/prompts/templates/eval_prompts/inline_comments/human.txt +49 -0
- janus/prompts/templates/eval_prompts/inline_comments/system.txt +1 -0
- janus/prompts/templates/eval_prompts/inline_comments/variables.json +3 -0
- janus/prompts/templates/micromanaged_mumps_v1.0/human.txt +23 -0
- janus/prompts/templates/micromanaged_mumps_v1.0/system.txt +3 -0
- janus/prompts/templates/micromanaged_mumps_v2.0/human.txt +28 -0
- janus/prompts/templates/micromanaged_mumps_v2.0/system.txt +3 -0
- janus/prompts/templates/micromanaged_mumps_v2.1/human.txt +29 -0
- janus/prompts/templates/micromanaged_mumps_v2.1/system.txt +3 -0
- janus/prompts/templates/multidocument/human.txt +15 -0
- janus/prompts/templates/multidocument/system.txt +1 -0
- janus/prompts/templates/partition/human.txt +22 -0
- janus/prompts/templates/partition/system.txt +1 -0
- janus/prompts/templates/partition/variables.json +4 -0
- janus/prompts/templates/pseudocode/human.txt +7 -0
- janus/prompts/templates/pseudocode/system.txt +7 -0
- janus/prompts/templates/refinement/fix_exceptions/human.txt +19 -0
- janus/prompts/templates/refinement/fix_exceptions/system.txt +1 -0
- janus/prompts/templates/refinement/format/code_format/human.txt +12 -0
- janus/prompts/templates/refinement/format/code_format/system.txt +1 -0
- janus/prompts/templates/refinement/format/requirements_format/human.txt +14 -0
- janus/prompts/templates/refinement/format/requirements_format/system.txt +1 -0
- janus/prompts/templates/refinement/hallucination/human.txt +13 -0
- janus/prompts/templates/refinement/hallucination/system.txt +1 -0
- janus/prompts/templates/refinement/reflection/human.txt +15 -0
- janus/prompts/templates/refinement/reflection/incose/human.txt +26 -0
- janus/prompts/templates/refinement/reflection/incose/system.txt +1 -0
- janus/prompts/templates/refinement/reflection/incose_deduplicate/human.txt +16 -0
- janus/prompts/templates/refinement/reflection/incose_deduplicate/system.txt +1 -0
- janus/prompts/templates/refinement/reflection/system.txt +1 -0
- janus/prompts/templates/refinement/revision/human.txt +16 -0
- janus/prompts/templates/refinement/revision/incose/human.txt +16 -0
- janus/prompts/templates/refinement/revision/incose/system.txt +1 -0
- janus/prompts/templates/refinement/revision/incose_deduplicate/human.txt +17 -0
- janus/prompts/templates/refinement/revision/incose_deduplicate/system.txt +1 -0
- janus/prompts/templates/refinement/revision/system.txt +1 -0
- janus/prompts/templates/refinement/uml/alc_fix_variables/human.txt +15 -0
- janus/prompts/templates/refinement/uml/alc_fix_variables/system.txt +2 -0
- janus/prompts/templates/refinement/uml/fix_connections/human.txt +15 -0
- janus/prompts/templates/refinement/uml/fix_connections/system.txt +2 -0
- janus/prompts/templates/requirements/human.txt +13 -0
- janus/prompts/templates/requirements/system.txt +2 -0
- janus/prompts/templates/retrieval/language_docs/human.txt +10 -0
- janus/prompts/templates/retrieval/language_docs/system.txt +1 -0
- janus/prompts/templates/simple/human.txt +16 -0
- janus/prompts/templates/simple/system.txt +3 -0
- janus/refiners/format.py +49 -0
- janus/refiners/refiner.py +143 -4
- janus/utils/enums.py +140 -111
- janus/utils/logger.py +2 -0
- {janus_llm-4.2.0.dist-info → janus_llm-4.3.5.dist-info}/METADATA +7 -7
- janus_llm-4.3.5.dist-info/RECORD +210 -0
- {janus_llm-4.2.0.dist-info → janus_llm-4.3.5.dist-info}/WHEEL +1 -1
- janus_llm-4.3.5.dist-info/entry_points.txt +3 -0
- janus/cli.py +0 -1343
- janus_llm-4.2.0.dist-info/RECORD +0 -113
- janus_llm-4.2.0.dist-info/entry_points.txt +0 -3
- {janus_llm-4.2.0.dist-info → janus_llm-4.3.5.dist-info}/LICENSE +0 -0
janus/llm/models_info.py
CHANGED
@@ -6,9 +6,13 @@ from typing import Callable, Protocol, TypeVar
|
|
6
6
|
from dotenv import load_dotenv
|
7
7
|
from langchain_community.llms import HuggingFaceTextGenInference
|
8
8
|
from langchain_core.runnables import Runnable
|
9
|
-
from langchain_openai import AzureChatOpenAI
|
9
|
+
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
10
10
|
|
11
|
-
from janus.llm.model_callbacks import
|
11
|
+
from janus.llm.model_callbacks import (
|
12
|
+
COST_PER_1K_TOKENS,
|
13
|
+
azure_model_reroutes,
|
14
|
+
openai_model_reroutes,
|
15
|
+
)
|
12
16
|
from janus.prompts.prompt import (
|
13
17
|
ChatGptPromptEngine,
|
14
18
|
ClaudePromptEngine,
|
@@ -46,6 +50,7 @@ except ImportError:
|
|
46
50
|
ModelType = TypeVar(
|
47
51
|
"ModelType",
|
48
52
|
AzureChatOpenAI,
|
53
|
+
ChatOpenAI,
|
49
54
|
HuggingFaceTextGenInference,
|
50
55
|
Bedrock,
|
51
56
|
BedrockChat,
|
@@ -127,7 +132,7 @@ bedrock_models = [
|
|
127
132
|
all_models = [*azure_models, *bedrock_models]
|
128
133
|
|
129
134
|
MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
|
130
|
-
|
135
|
+
"OpenAI": ChatOpenAI,
|
131
136
|
"HuggingFace": HuggingFaceTextGenInference,
|
132
137
|
"Azure": AzureChatOpenAI,
|
133
138
|
"Bedrock": Bedrock,
|
@@ -137,7 +142,7 @@ MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
|
|
137
142
|
|
138
143
|
|
139
144
|
MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
140
|
-
|
145
|
+
**{m: ChatGptPromptEngine for m in openai_models},
|
141
146
|
**{m: ChatGptPromptEngine for m in azure_models},
|
142
147
|
**{m: ClaudePromptEngine for m in claude_models},
|
143
148
|
**{m: Llama2PromptEngine for m in llama2_models},
|
@@ -148,7 +153,7 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
|
|
148
153
|
}
|
149
154
|
|
150
155
|
MODEL_ID_TO_LONG_ID = {
|
151
|
-
|
156
|
+
**{m: mr for m, mr in openai_model_reroutes.items()},
|
152
157
|
**{m: mr for m, mr in azure_model_reroutes.items()},
|
153
158
|
"bedrock-claude-v2": "anthropic.claude-v2",
|
154
159
|
"bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
|
@@ -181,7 +186,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
|
|
181
186
|
MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
|
182
187
|
|
183
188
|
MODEL_TYPES: dict[str, PromptEngine] = {
|
184
|
-
|
189
|
+
**{m: "OpenAI" for m in openai_models},
|
185
190
|
**{m: "Azure" for m in azure_models},
|
186
191
|
**{m: "BedrockChat" for m in bedrock_models},
|
187
192
|
}
|
@@ -243,6 +248,7 @@ def load_model(model_id) -> JanusModel:
|
|
243
248
|
token_limit = model_config["token_limit"]
|
244
249
|
input_token_cost = model_config["model_cost"]["input"]
|
245
250
|
output_token_cost = model_config["model_cost"]["output"]
|
251
|
+
input_token_proportion = model_config["input_token_proportion"]
|
246
252
|
|
247
253
|
elif model_id in DEFAULT_MODELS:
|
248
254
|
model_id = model_id
|
@@ -253,6 +259,7 @@ def load_model(model_id) -> JanusModel:
|
|
253
259
|
token_limit = 0
|
254
260
|
input_token_cost = 0.0
|
255
261
|
output_token_cost = 0.0
|
262
|
+
input_token_proportion = 0.4
|
256
263
|
if model_long_id in TOKEN_LIMITS:
|
257
264
|
token_limit = TOKEN_LIMITS[model_long_id]
|
258
265
|
if model_long_id in COST_PER_1K_TOKENS:
|
@@ -282,22 +289,22 @@ def load_model(model_id) -> JanusModel:
|
|
282
289
|
elif model_type_name == "OpenAI":
|
283
290
|
model_args.update(
|
284
291
|
openai_api_key=str(os.getenv("OPENAI_API_KEY")),
|
285
|
-
openai_organization=str(os.getenv("OPENAI_ORG_ID")),
|
286
292
|
)
|
287
293
|
# log.warning("Do NOT use this model in sensitive environments!")
|
288
294
|
# log.warning("If you would like to cancel, please press Ctrl+C.")
|
289
295
|
# log.warning("Waiting 10 seconds...")
|
290
296
|
# Give enough time for the user to read the warnings and cancel
|
291
297
|
# time.sleep(10)
|
292
|
-
raise DeprecationWarning("OpenAI models are no longer supported.")
|
298
|
+
# raise DeprecationWarning("OpenAI models are no longer supported.")
|
293
299
|
|
294
300
|
elif model_type_name == "Azure":
|
295
301
|
model_args.update(
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
302
|
+
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
303
|
+
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
304
|
+
api_version=os.getenv("OPENAI_API_VERSION", "2024-02-01"),
|
305
|
+
azure_deployment=model_id,
|
306
|
+
request_timeout=3600,
|
307
|
+
max_tokens=4096,
|
301
308
|
)
|
302
309
|
|
303
310
|
model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
|
@@ -305,15 +312,20 @@ def load_model(model_id) -> JanusModel:
|
|
305
312
|
|
306
313
|
class JanusModel(model_type):
|
307
314
|
model_id: str
|
315
|
+
# model_name is for LangChain compatibility
|
316
|
+
# It searches for `self.model_name` when counting tokens
|
317
|
+
model_name: str
|
308
318
|
short_model_id: str
|
309
319
|
model_type_name: str
|
310
320
|
token_limit: int
|
321
|
+
input_token_proportion: float
|
311
322
|
input_token_cost: float
|
312
323
|
output_token_cost: float
|
313
324
|
prompt_engine: type[PromptEngine]
|
314
325
|
|
315
326
|
model_args.update(
|
316
327
|
model_id=MODEL_ID_TO_LONG_ID[model_id],
|
328
|
+
model_name=model_id, # This is for LangChain compatibility
|
317
329
|
short_model_id=model_id,
|
318
330
|
)
|
319
331
|
|
@@ -322,6 +334,7 @@ def load_model(model_id) -> JanusModel:
|
|
322
334
|
token_limit=token_limit,
|
323
335
|
input_token_cost=input_token_cost,
|
324
336
|
output_token_cost=output_token_cost,
|
337
|
+
input_token_proportion=input_token_proportion,
|
325
338
|
prompt_engine=prompt_engine,
|
326
339
|
**model_args,
|
327
340
|
)
|
@@ -0,0 +1,8 @@
|
|
1
|
+
Based on the following target written in the {language} programming language, how would you rate the code clarity of the target on a scale of integers from 1 to 10? Higher is better.
|
2
|
+
|
3
|
+
Think through your answer before selecting a rating with the following format:
|
4
|
+
|
5
|
+
Target: the target code
|
6
|
+
{format_instructions}
|
7
|
+
|
8
|
+
Target: {target}
|
@@ -0,0 +1,16 @@
|
|
1
|
+
Use the following rubric to evaluate the target written in the {language} programming language:
|
2
|
+
|
3
|
+
Rubric:
|
4
|
+
Does the comment address all capabilities of the relevant source code?
|
5
|
+
|
6
|
+
10 - All essential functionality is documented.
|
7
|
+
6-9 - Most essential functionality is documented.
|
8
|
+
2-5 - Little essential functionality is documented.
|
9
|
+
1 - No essential functionality is documented.
|
10
|
+
|
11
|
+
Think through your answer before selecting a rating with the following format:
|
12
|
+
|
13
|
+
Target: the target code
|
14
|
+
{format_instructions}
|
15
|
+
|
16
|
+
Target: {target}
|
@@ -0,0 +1,10 @@
|
|
1
|
+
Based on the following target and reference written in the {language} programming language, how would you rate the faithfulness of the target to the original reference on a scale of integers from 1 to 10? Higher is better.
|
2
|
+
|
3
|
+
Think through your answer before selecting a rating with the following format:
|
4
|
+
|
5
|
+
Target: the target code
|
6
|
+
Reference: the reference code that we are judging the target against
|
7
|
+
{format_instructions}
|
8
|
+
|
9
|
+
Target: {target}
|
10
|
+
Reference: {reference}
|
@@ -0,0 +1,16 @@
|
|
1
|
+
Use the following rubric to evaluate the target written in the {language} programming language:
|
2
|
+
|
3
|
+
Rubric:
|
4
|
+
Does the comment provide true information?
|
5
|
+
|
6
|
+
10 - The comment provides only true information.
|
7
|
+
6-9 - The comment provides mostly true information.
|
8
|
+
2-5 - The comment provides mostly untrue information.
|
9
|
+
1 - The comment is completely untrue.
|
10
|
+
|
11
|
+
Think through your answer before selecting a rating with the following format:
|
12
|
+
|
13
|
+
Target: the target code
|
14
|
+
{format_instructions}
|
15
|
+
|
16
|
+
Target: {target}
|
@@ -0,0 +1,8 @@
|
|
1
|
+
Based on the following target written in the {language} programming language, how would you rate the code quality of the target on a scale of integers from 1 to 10? Higher is better.
|
2
|
+
|
3
|
+
Think through your answer before selecting a rating with the following format:
|
4
|
+
|
5
|
+
Target: the target code
|
6
|
+
{format_instructions}
|
7
|
+
|
8
|
+
Target: {target}
|
@@ -0,0 +1,16 @@
|
|
1
|
+
Use the following rubric to evaluate the target written in the {language} programming language:
|
2
|
+
|
3
|
+
Rubric:
|
4
|
+
Is the comment clear to read?
|
5
|
+
|
6
|
+
10 - The comment is well-written.
|
7
|
+
6-9 - The comment has few problems.
|
8
|
+
2-5 - The comment has many problems.
|
9
|
+
1 - The comment is unreadable.
|
10
|
+
|
11
|
+
Think through your answer before selecting a rating with the following format:
|
12
|
+
|
13
|
+
Target: the target code
|
14
|
+
{format_instructions}
|
15
|
+
|
16
|
+
Target: {target}
|
@@ -0,0 +1,16 @@
|
|
1
|
+
Use the following rubric to evaluate the target written in the {language} programming language:
|
2
|
+
|
3
|
+
Rubric:
|
4
|
+
Is the comment useful?
|
5
|
+
|
6
|
+
10 - The comment helps an expert programmer understand the code better.
|
7
|
+
6-9 - The comment helps an average programmer understand the code better.
|
8
|
+
2-5 - The comment documents only trivial functionality.
|
9
|
+
1 - The comment is not useful at any level.
|
10
|
+
|
11
|
+
Think through your answer before selecting a rating with the following format:
|
12
|
+
|
13
|
+
Target: the target code
|
14
|
+
{format_instructions}
|
15
|
+
|
16
|
+
Target: {target}
|
janus/parsers/code_parser.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1
1
|
import re
|
2
2
|
|
3
|
-
from langchain_core.exceptions import OutputParserException
|
4
3
|
from langchain_core.messages import BaseMessage
|
5
4
|
|
6
|
-
from janus.parsers.parser import JanusParser
|
5
|
+
from janus.parsers.parser import JanusParser, JanusParserException
|
7
6
|
from janus.utils.logger import create_logger
|
8
7
|
|
9
8
|
log = create_logger(__name__)
|
@@ -18,8 +17,9 @@ class CodeParser(JanusParser):
|
|
18
17
|
pattern = rf"```[^\S\r\n]*(?:{self.language}[^\S\r\n]*)?\n?(.*?)\n*```"
|
19
18
|
code = re.search(pattern, text, re.DOTALL)
|
20
19
|
if code is None:
|
21
|
-
raise
|
22
|
-
|
20
|
+
raise JanusParserException(
|
21
|
+
text,
|
22
|
+
"Code not find code between triple square brackets",
|
23
23
|
)
|
24
24
|
return str(code.group(1))
|
25
25
|
|
janus/parsers/doc_parser.py
CHANGED
@@ -8,7 +8,7 @@ from langchain_core.messages import BaseMessage
|
|
8
8
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
9
9
|
|
10
10
|
from janus.language.block import CodeBlock
|
11
|
-
from janus.parsers.parser import JanusParser
|
11
|
+
from janus.parsers.parser import JanusParser, JanusParserException
|
12
12
|
from janus.utils.logger import create_logger
|
13
13
|
|
14
14
|
log = create_logger(__name__)
|
@@ -86,7 +86,7 @@ class MultiDocumentationParser(JanusParser, PydanticOutputParser):
|
|
86
86
|
return str(self.__class__.name)
|
87
87
|
|
88
88
|
|
89
|
-
class
|
89
|
+
class ClozeDocumentationParser(JanusParser):
|
90
90
|
expected_keys: set[str]
|
91
91
|
|
92
92
|
def __init__(self):
|
@@ -107,11 +107,12 @@ class MadlibsDocumentationParser(JanusParser):
|
|
107
107
|
obj = parse_json_markdown(text)
|
108
108
|
except json.JSONDecodeError as e:
|
109
109
|
log.debug(f"Invalid JSON object. Output:\n{text}")
|
110
|
-
raise
|
110
|
+
raise JanusParserException(text, f"Got invalid JSON object. Error: {e}")
|
111
111
|
|
112
112
|
if not isinstance(obj, dict):
|
113
|
-
raise
|
114
|
-
|
113
|
+
raise JanusParserException(
|
114
|
+
text,
|
115
|
+
f"Got invalid return object. Expected a dictionary, but got {type(obj)}",
|
115
116
|
)
|
116
117
|
|
117
118
|
seen_keys = set(obj.keys())
|
@@ -122,9 +123,10 @@ class MadlibsDocumentationParser(JanusParser):
|
|
122
123
|
if invalid_keys:
|
123
124
|
log.debug(f"Invalid keys: {invalid_keys}")
|
124
125
|
log.debug(f"Missing keys: {missing_keys}")
|
125
|
-
raise
|
126
|
+
raise JanusParserException(
|
127
|
+
text,
|
126
128
|
f"Got invalid return object. Missing the following expected "
|
127
|
-
f"keys: {missing_keys}"
|
129
|
+
f"keys: {missing_keys}",
|
128
130
|
)
|
129
131
|
|
130
132
|
for key in invalid_keys:
|
@@ -132,9 +134,10 @@ class MadlibsDocumentationParser(JanusParser):
|
|
132
134
|
|
133
135
|
for value in obj.values():
|
134
136
|
if not isinstance(value, str):
|
135
|
-
raise
|
137
|
+
raise JanusParserException(
|
138
|
+
text,
|
136
139
|
f"Got invalid return object. Expected all string values,"
|
137
|
-
f' but got type "{type(value)}"'
|
140
|
+
f' but got type "{type(value)}"',
|
138
141
|
)
|
139
142
|
|
140
143
|
return json.dumps(obj)
|
@@ -0,0 +1,134 @@
|
|
1
|
+
import json
|
2
|
+
import random
|
3
|
+
import uuid
|
4
|
+
from typing import List
|
5
|
+
|
6
|
+
from langchain.output_parsers import PydanticOutputParser
|
7
|
+
from langchain_core.exceptions import OutputParserException
|
8
|
+
from langchain_core.messages import BaseMessage
|
9
|
+
from langchain_core.pydantic_v1 import BaseModel, Field, validator
|
10
|
+
|
11
|
+
from janus.language.block import CodeBlock
|
12
|
+
from janus.parsers.parser import JanusParser
|
13
|
+
from janus.utils.logger import create_logger
|
14
|
+
|
15
|
+
log = create_logger(__name__)

# Shared generator used to derive requirement IDs; IncoseParser.parse_input
# re-seeds it from each input text so identical inputs yield identical IDs.
RNG = random.Random()


# Pass/fail verdict with justification for one criterion of one requirement.
# Intentionally no class docstring: pydantic v1 copies model docstrings into the
# JSON schema, which PydanticOutputParser embeds in the LLM format instructions.
class Criteria(BaseModel):
    reasoning: str = Field(description="A short explanation for the given assessment")
    # `description` must be passed by keyword: Field's first positional argument
    # is the field *default*, which would make `score` optional and silently
    # fill it with this text when the LLM omits it.
    score: str = Field(description="A simple `pass` or `fail`")

    @validator("score")
    def score_is_valid(cls, v: str):
        """Normalize the score to lowercase and accept only "pass" or "fail"."""
        v = v.lower().strip()
        if v not in {"pass", "fail"}:
            raise OutputParserException("Score must be either 'pass' or 'fail'")
        return v
|
29
|
+
|
30
|
+
|
31
|
+
# One LLM evaluation of a single requirement against criteria C1 through C9.
# Intentionally no class docstring: pydantic v1 copies model docstrings into the
# JSON schema, which PydanticOutputParser embeds in the LLM format instructions.
class Requirement(BaseModel):
    # This ID is the one IncoseParser assigns to each requirement, so describe it
    # as a requirement ID (not a comment ID) to avoid confusing the LLM.
    requirement_id: str = Field(description="The 8-character requirement ID")
    requirement: str = Field(description="The original requirement being evaluated")
    C1: Criteria
    C2: Criteria
    C3: Criteria
    C4: Criteria
    C5: Criteria
    C6: Criteria
    C7: Criteria
    C8: Criteria
    C9: Criteria
|
43
|
+
|
44
|
+
|
45
|
+
# Root model for the LLM's response: a bare JSON list of Requirement evaluations.
# Intentionally no class docstring: pydantic v1 copies model docstrings into the
# JSON schema, which PydanticOutputParser embeds in the LLM format instructions.
class RequirementList(BaseModel):
    # The description text is shown to the LLM, so keep field quoting consistent
    # (backticks) and avoid stray double spaces between concatenated fragments.
    __root__: List[Requirement] = Field(
        description=(
            "A list of requirement evaluations. Each element should include"
            " the requirement's 8-character ID in the `requirement_id` field,"
            " the original requirement in the `requirement` field,"
            " and nine score objects corresponding to each criterion"
            " (`C1` through `C9`)."
        )
    )
|
54
|
+
|
55
|
+
|
56
|
+
class IncoseParser(JanusParser, PydanticOutputParser):
    """Parse LLM evaluations of requirements against criteria C1 through C9.

    ``parse_input`` assigns each input requirement a short ID (derived from
    ``RNG`` seeded with the input text) and rewrites the prompt input to list
    requirements by ID. ``parse`` validates the LLM's JSON list of
    ``Requirement`` evaluations against those IDs and returns a JSON object
    keyed by requirement ID.
    """

    # Maps each generated 8-character ID to its original requirement; filled by
    # parse_input and consulted by parse.
    requirements: dict[str, str]

    def __init__(self):
        PydanticOutputParser.__init__(
            self,
            pydantic_object=RequirementList,
            requirements={},
        )

    def parse_input(self, block: CodeBlock) -> str:
        """Tag each requirement in the block's JSON text with a unique short ID.

        Args:
            block: Block whose parsed text is a JSON object containing a
                ``"requirements"`` entry.

        Returns:
            The JSON object re-serialized, with ``"requirements"`` replaced by a
            single string of ``"Requirement <id> : <requirement>"`` entries
            separated by blank lines.
        """
        # TODO: Perform comment stripping/placeholding here rather than in script
        text = super().parse_input(block)
        # Seeding from the input makes the generated IDs reproducible.
        RNG.seed(text)

        obj = json.loads(text)

        # NOTE(review): an earlier comment suggested requirements may arrive
        # wrapped in a double list; the value is used as-is here — confirm
        # against the producer of this JSON.
        reqs = obj["requirements"]

        # Draw IDs until there is one unique ID per requirement. A dict (rather
        # than a set) preserves insertion order, so the ID -> requirement pairing
        # is deterministic and does not depend on string hash randomization.
        req_ids: dict[str, None] = {}
        while len(req_ids) < len(reqs):
            req_ids[str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8]] = None

        self.requirements = dict(zip(req_ids, reqs))
        obj["requirements"] = "\n\n".join(
            f"Requirement {rid} : {req}" for rid, req in self.requirements.items()
        )
        return json.dumps(obj)

    def parse(self, text: str | BaseMessage) -> str:
        """Validate the LLM's requirement evaluations and key them by ID.

        Args:
            text: Raw LLM output (string or message) containing a JSON list of
                requirement evaluations.

        Returns:
            A JSON object mapping each requirement ID to its evaluation, with
            ``requirement_id`` removed and ``requirement`` set to the original
            requirement text recorded by ``parse_input``.

        Raises:
            JanusParserException: (a subclass of ``OutputParserException``) if no
                JSON list is found, the list fails to parse or validate, or any
                requirement ID issued by ``parse_input`` is missing.
        """
        from janus.parsers.parser import JanusParserException

        if isinstance(text, BaseMessage):
            text = str(text.content)

        # Strip everything outside the outermost JSON list
        begin, end = text.find("["), text.rfind("]")
        if begin == -1 or end < begin:
            raise JanusParserException(text, "Could not find a JSON list in the output")
        json_text = text[begin : end + 1]

        try:
            out: RequirementList = super().parse(json_text)
        except json.JSONDecodeError as e:
            log.debug(f"Invalid JSON object. Output:\n{json_text}")
            raise JanusParserException(text, f"Got invalid JSON object. Error: {e}") from e
        except OutputParserException as e:
            # Re-raise so the full raw output travels with the error, matching
            # the other Janus parsers.
            raise JanusParserException(text, str(e)) from e

        evals: dict[str, dict] = {c.requirement_id: c.dict() for c in out.__root__}

        seen_keys = set(evals)
        expected_keys = set(self.requirements)
        missing_keys = expected_keys - seen_keys
        invalid_keys = seen_keys - expected_keys
        if invalid_keys:
            log.debug(f"Invalid keys: {invalid_keys}")
        if missing_keys:
            log.debug(f"Missing keys: {missing_keys}")
            raise JanusParserException(
                text,
                f"Got invalid return object. Missing the following expected "
                f"keys: {missing_keys}",
            )

        # Drop evaluations for IDs that parse_input never issued
        for key in invalid_keys:
            del evals[key]

        for rid, evaluation in evals.items():
            evaluation["requirement"] = self.requirements[rid]
            evaluation.pop("requirement_id")

        return json.dumps(evals)

    def parse_combined_output(self, text: str) -> str:
        """Merge one JSON object per non-blank line into a single JSON object.

        Later lines overwrite earlier ones on duplicate keys. Blank input yields
        ``"{}"``.
        """
        if not text.strip():
            return json.dumps({})
        output_obj: dict = {}
        for line in text.split("\n"):
            if line.strip():
                output_obj.update(json.loads(line))
        return json.dumps(output_obj)
|
@@ -0,0 +1,112 @@
|
|
1
|
+
import json
|
2
|
+
import re
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from langchain.output_parsers import PydanticOutputParser
|
6
|
+
from langchain_core.exceptions import OutputParserException
|
7
|
+
from langchain_core.messages import BaseMessage
|
8
|
+
from langchain_core.pydantic_v1 import BaseModel, Field, conint
|
9
|
+
|
10
|
+
from janus.language.block import CodeBlock
|
11
|
+
from janus.parsers.parser import JanusParser
|
12
|
+
from janus.utils.logger import create_logger
|
13
|
+
|
14
|
+
log = create_logger(__name__)


# Integer score (1-4) with justification for one metric of one inline comment.
# Intentionally no class docstring: pydantic v1 copies model docstrings into the
# JSON schema, which PydanticOutputParser embeds in the LLM format instructions.
class Criteria(BaseModel):
    reasoning: str = Field(description="A short explanation for the given score")
    # Constrained to an integer between 1 and 4
    score: conint(ge=1, le=4) = Field(  # type: ignore
        description="An integer score between 1 and 4 (inclusive), 4 being the best"
    )
|
23
|
+
|
24
|
+
|
25
|
+
# One LLM evaluation of a single inline comment, scored on four metrics.
# Intentionally no class docstring: pydantic v1 copies model docstrings into the
# JSON schema, which PydanticOutputParser embeds in the LLM format instructions.
class Comment(BaseModel):
    # Must match an ID recorded by InlineCommentParser.parse_input
    comment_id: str = Field(description="The 8-character comment ID")
    completeness: Criteria = Field(description="The completeness of the comment")
    hallucination: Criteria = Field(description="The factualness of the comment")
    readability: Criteria = Field(description="The readability of the comment")
    usefulness: Criteria = Field(description="The usefulness of the comment")
|
31
|
+
|
32
|
+
|
33
|
+
# Root model for the LLM's response: a bare JSON list of Comment evaluations.
# Intentionally no class docstring: pydantic v1 copies model docstrings into the
# JSON schema, which PydanticOutputParser embeds in the LLM format instructions.
class CommentList(BaseModel):
    __root__: list[Comment] = Field(
        description=(
            "A list of inline comment evaluations. Each element should include"
            " the comment's 8-character ID in the `comment_id` field, and four"
            " score objects corresponding to each metric (`completeness`,"
            " `hallucination`, `readability`, and `usefulness`)."
        )
    )
|
42
|
+
|
43
|
+
|
44
|
+
class InlineCommentParser(JanusParser, PydanticOutputParser):
    """Parse LLM evaluations of inline comments keyed by 8-character comment IDs.

    ``parse_input`` records the ``<BLOCK_COMMENT id>`` / ``<INLINE_COMMENT id>``
    placeholders found in the input text. ``parse`` validates the LLM's JSON list
    of ``Comment`` evaluations against those IDs and returns a JSON object keyed
    by comment ID.
    """

    # Maps each placeholder's 8-character ID to the text following it on the
    # same line; filled by parse_input and consulted by parse.
    comments: dict[str, str]

    def __init__(self):
        PydanticOutputParser.__init__(
            self,
            pydantic_object=CommentList,
            # Empty mapping, matching the declared dict[str, str] field type
            comments={},
        )

    def parse_input(self, block: CodeBlock) -> str:
        """Record comment placeholders in the block's text and return it unchanged.

        Args:
            block: Block whose parsed text contains lines of the form
                ``<BLOCK_COMMENT abcd1234> text`` or
                ``<INLINE_COMMENT abcd1234> text``.

        Returns:
            The parsed block text, unmodified.
        """
        # TODO: Perform comment stripping/placeholding here rather than in script
        text = super().parse_input(block)
        self.comments = dict(
            re.findall(
                r"<(?:BLOCK|INLINE)_COMMENT (\w{8})> (.*)$",
                text,
                flags=re.MULTILINE,
            )
        )
        return text

    def parse(self, text: str | BaseMessage) -> str:
        """Validate the LLM's comment evaluations and key them by comment ID.

        Args:
            text: Raw LLM output (string or message) containing a JSON list of
                comment evaluations.

        Returns:
            A JSON object mapping each comment ID to its evaluation, with
            ``comment_id`` removed and ``comment`` set to the comment text
            recorded by ``parse_input``.

        Raises:
            JanusParserException: (a subclass of ``OutputParserException``) if no
                JSON list is found, the list fails to parse or validate, or any
                comment ID recorded by ``parse_input`` is missing.
        """
        from janus.parsers.parser import JanusParserException

        if isinstance(text, BaseMessage):
            text = str(text.content)

        # Strip everything outside the outermost JSON list
        begin, end = text.find("["), text.rfind("]")
        if begin == -1 or end < begin:
            raise JanusParserException(text, "Could not find a JSON list in the output")
        json_text = text[begin : end + 1]

        try:
            out: CommentList = super().parse(json_text)
        except json.JSONDecodeError as e:
            log.debug(f"Invalid JSON object. Output:\n{json_text}")
            raise JanusParserException(text, f"Got invalid JSON object. Error: {e}") from e
        except OutputParserException as e:
            # Re-raise so the full raw output travels with the error, matching
            # the other Janus parsers.
            raise JanusParserException(text, str(e)) from e

        evals: dict[str, Any] = {c.comment_id: c.dict() for c in out.__root__}

        seen_keys = set(evals)
        expected_keys = set(self.comments)
        missing_keys = expected_keys - seen_keys
        invalid_keys = seen_keys - expected_keys
        if invalid_keys:
            log.debug(f"Invalid keys: {invalid_keys}")
        if missing_keys:
            log.debug(f"Missing keys: {missing_keys}")
            raise JanusParserException(
                text,
                f"Got invalid return object. Missing the following expected "
                f"keys: {missing_keys}",
            )

        # Drop evaluations for IDs that were not found in the input
        for key in invalid_keys:
            del evals[key]

        for cid, evaluation in evals.items():
            evaluation["comment"] = self.comments[cid]
            evaluation.pop("comment_id")

        return json.dumps(evals)

    def parse_combined_output(self, text: str) -> str:
        """Merge one JSON object per non-blank line into a single JSON object.

        Later lines overwrite earlier ones on duplicate keys. Blank input yields
        ``"{}"``.
        """
        if not text.strip():
            return json.dumps({})
        output_obj: dict = {}
        for line in text.split("\n"):
            if line.strip():
                output_obj.update(json.loads(line))
        return json.dumps(output_obj)
|
janus/parsers/parser.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
from langchain.schema.output_parser import BaseOutputParser
|
2
|
+
from langchain_core.exceptions import OutputParserException
|
2
3
|
from langchain_core.messages import BaseMessage
|
3
4
|
from langchain_core.output_parsers import StrOutputParser
|
4
5
|
|
@@ -49,3 +50,9 @@ class GenericParser(JanusParser, StrOutputParser):
|
|
49
50
|
|
50
51
|
def get_format_instructions(self) -> str:
|
51
52
|
return "Output should be a string"
|
53
|
+
|
54
|
+
|
55
|
+
class JanusParserException(OutputParserException):
    """An ``OutputParserException`` that also carries the raw, unparsed output.

    Args:
        unparsed_output: The model output that could not be parsed; exposed as
            ``self.unparsed_output``.
        *args, **kwargs: Forwarded unchanged to ``OutputParserException``.
    """

    def __init__(self, unparsed_output, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.unparsed_output = unparsed_output
|