janus-llm 4.2.0__py3-none-any.whl → 4.3.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (134) hide show
  1. janus/__init__.py +1 -1
  2. janus/__main__.py +1 -1
  3. janus/_tests/evaluator_tests/EvalReadMe.md +85 -0
  4. janus/_tests/evaluator_tests/incose_tests/incose_large_test.json +39 -0
  5. janus/_tests/evaluator_tests/incose_tests/incose_small_test.json +17 -0
  6. janus/_tests/evaluator_tests/inline_comment_tests/mumps_inline_comment_test.m +71 -0
  7. janus/_tests/test_cli.py +3 -2
  8. janus/cli/aggregate.py +135 -0
  9. janus/cli/cli.py +111 -0
  10. janus/cli/constants.py +43 -0
  11. janus/cli/database.py +289 -0
  12. janus/cli/diagram.py +178 -0
  13. janus/cli/document.py +174 -0
  14. janus/cli/embedding.py +122 -0
  15. janus/cli/llm.py +187 -0
  16. janus/cli/partition.py +125 -0
  17. janus/cli/self_eval.py +149 -0
  18. janus/cli/translate.py +183 -0
  19. janus/converter/__init__.py +1 -1
  20. janus/converter/_tests/test_translate.py +2 -0
  21. janus/converter/converter.py +129 -92
  22. janus/converter/document.py +21 -14
  23. janus/converter/evaluate.py +237 -4
  24. janus/converter/translate.py +3 -3
  25. janus/embedding/collections.py +1 -1
  26. janus/language/alc/_tests/alc.asm +3779 -0
  27. janus/language/alc/_tests/test_alc.py +1 -1
  28. janus/language/alc/alc.py +9 -4
  29. janus/language/binary/_tests/hello.bin +0 -0
  30. janus/language/block.py +47 -12
  31. janus/language/file.py +1 -1
  32. janus/language/mumps/_tests/mumps.m +235 -0
  33. janus/language/splitter.py +31 -23
  34. janus/language/treesitter/_tests/languages/fortran.f90 +416 -0
  35. janus/language/treesitter/_tests/languages/ibmhlasm.asm +16 -0
  36. janus/language/treesitter/_tests/languages/matlab.m +225 -0
  37. janus/language/treesitter/treesitter.py +9 -1
  38. janus/llm/models_info.py +26 -13
  39. janus/metrics/_tests/asm_test_file.asm +10 -0
  40. janus/metrics/_tests/mumps_test_file.m +6 -0
  41. janus/metrics/_tests/test_treesitter_metrics.py +1 -1
  42. janus/metrics/prompts/clarity.txt +8 -0
  43. janus/metrics/prompts/completeness.txt +16 -0
  44. janus/metrics/prompts/faithfulness.txt +10 -0
  45. janus/metrics/prompts/hallucination.txt +16 -0
  46. janus/metrics/prompts/quality.txt +8 -0
  47. janus/metrics/prompts/readability.txt +16 -0
  48. janus/metrics/prompts/usefulness.txt +16 -0
  49. janus/parsers/code_parser.py +4 -4
  50. janus/parsers/doc_parser.py +12 -9
  51. janus/parsers/eval_parsers/incose_parser.py +134 -0
  52. janus/parsers/eval_parsers/inline_comment_parser.py +112 -0
  53. janus/parsers/parser.py +7 -0
  54. janus/parsers/partition_parser.py +47 -13
  55. janus/parsers/reqs_parser.py +8 -5
  56. janus/parsers/uml.py +5 -4
  57. janus/prompts/prompt.py +2 -2
  58. janus/prompts/templates/README.md +30 -0
  59. janus/prompts/templates/basic_aggregation/human.txt +6 -0
  60. janus/prompts/templates/basic_aggregation/system.txt +1 -0
  61. janus/prompts/templates/basic_refinement/human.txt +14 -0
  62. janus/prompts/templates/basic_refinement/system.txt +1 -0
  63. janus/prompts/templates/diagram/human.txt +9 -0
  64. janus/prompts/templates/diagram/system.txt +1 -0
  65. janus/prompts/templates/diagram_with_documentation/human.txt +15 -0
  66. janus/prompts/templates/diagram_with_documentation/system.txt +1 -0
  67. janus/prompts/templates/document/human.txt +10 -0
  68. janus/prompts/templates/document/system.txt +1 -0
  69. janus/prompts/templates/document_cloze/human.txt +11 -0
  70. janus/prompts/templates/document_cloze/system.txt +1 -0
  71. janus/prompts/templates/document_cloze/variables.json +4 -0
  72. janus/prompts/templates/document_cloze/variables_asm.json +4 -0
  73. janus/prompts/templates/document_inline/human.txt +13 -0
  74. janus/prompts/templates/eval_prompts/incose/human.txt +32 -0
  75. janus/prompts/templates/eval_prompts/incose/system.txt +1 -0
  76. janus/prompts/templates/eval_prompts/incose/variables.json +3 -0
  77. janus/prompts/templates/eval_prompts/inline_comments/human.txt +49 -0
  78. janus/prompts/templates/eval_prompts/inline_comments/system.txt +1 -0
  79. janus/prompts/templates/eval_prompts/inline_comments/variables.json +3 -0
  80. janus/prompts/templates/micromanaged_mumps_v1.0/human.txt +23 -0
  81. janus/prompts/templates/micromanaged_mumps_v1.0/system.txt +3 -0
  82. janus/prompts/templates/micromanaged_mumps_v2.0/human.txt +28 -0
  83. janus/prompts/templates/micromanaged_mumps_v2.0/system.txt +3 -0
  84. janus/prompts/templates/micromanaged_mumps_v2.1/human.txt +29 -0
  85. janus/prompts/templates/micromanaged_mumps_v2.1/system.txt +3 -0
  86. janus/prompts/templates/multidocument/human.txt +15 -0
  87. janus/prompts/templates/multidocument/system.txt +1 -0
  88. janus/prompts/templates/partition/human.txt +22 -0
  89. janus/prompts/templates/partition/system.txt +1 -0
  90. janus/prompts/templates/partition/variables.json +4 -0
  91. janus/prompts/templates/pseudocode/human.txt +7 -0
  92. janus/prompts/templates/pseudocode/system.txt +7 -0
  93. janus/prompts/templates/refinement/fix_exceptions/human.txt +19 -0
  94. janus/prompts/templates/refinement/fix_exceptions/system.txt +1 -0
  95. janus/prompts/templates/refinement/format/code_format/human.txt +12 -0
  96. janus/prompts/templates/refinement/format/code_format/system.txt +1 -0
  97. janus/prompts/templates/refinement/format/requirements_format/human.txt +14 -0
  98. janus/prompts/templates/refinement/format/requirements_format/system.txt +1 -0
  99. janus/prompts/templates/refinement/hallucination/human.txt +13 -0
  100. janus/prompts/templates/refinement/hallucination/system.txt +1 -0
  101. janus/prompts/templates/refinement/reflection/human.txt +15 -0
  102. janus/prompts/templates/refinement/reflection/incose/human.txt +26 -0
  103. janus/prompts/templates/refinement/reflection/incose/system.txt +1 -0
  104. janus/prompts/templates/refinement/reflection/incose_deduplicate/human.txt +16 -0
  105. janus/prompts/templates/refinement/reflection/incose_deduplicate/system.txt +1 -0
  106. janus/prompts/templates/refinement/reflection/system.txt +1 -0
  107. janus/prompts/templates/refinement/revision/human.txt +16 -0
  108. janus/prompts/templates/refinement/revision/incose/human.txt +16 -0
  109. janus/prompts/templates/refinement/revision/incose/system.txt +1 -0
  110. janus/prompts/templates/refinement/revision/incose_deduplicate/human.txt +17 -0
  111. janus/prompts/templates/refinement/revision/incose_deduplicate/system.txt +1 -0
  112. janus/prompts/templates/refinement/revision/system.txt +1 -0
  113. janus/prompts/templates/refinement/uml/alc_fix_variables/human.txt +15 -0
  114. janus/prompts/templates/refinement/uml/alc_fix_variables/system.txt +2 -0
  115. janus/prompts/templates/refinement/uml/fix_connections/human.txt +15 -0
  116. janus/prompts/templates/refinement/uml/fix_connections/system.txt +2 -0
  117. janus/prompts/templates/requirements/human.txt +13 -0
  118. janus/prompts/templates/requirements/system.txt +2 -0
  119. janus/prompts/templates/retrieval/language_docs/human.txt +10 -0
  120. janus/prompts/templates/retrieval/language_docs/system.txt +1 -0
  121. janus/prompts/templates/simple/human.txt +16 -0
  122. janus/prompts/templates/simple/system.txt +3 -0
  123. janus/refiners/format.py +49 -0
  124. janus/refiners/refiner.py +143 -4
  125. janus/utils/enums.py +140 -111
  126. janus/utils/logger.py +2 -0
  127. {janus_llm-4.2.0.dist-info → janus_llm-4.3.5.dist-info}/METADATA +7 -7
  128. janus_llm-4.3.5.dist-info/RECORD +210 -0
  129. {janus_llm-4.2.0.dist-info → janus_llm-4.3.5.dist-info}/WHEEL +1 -1
  130. janus_llm-4.3.5.dist-info/entry_points.txt +3 -0
  131. janus/cli.py +0 -1343
  132. janus_llm-4.2.0.dist-info/RECORD +0 -113
  133. janus_llm-4.2.0.dist-info/entry_points.txt +0 -3
  134. {janus_llm-4.2.0.dist-info → janus_llm-4.3.5.dist-info}/LICENSE +0 -0
janus/llm/models_info.py CHANGED
@@ -6,9 +6,13 @@ from typing import Callable, Protocol, TypeVar
6
6
  from dotenv import load_dotenv
7
7
  from langchain_community.llms import HuggingFaceTextGenInference
8
8
  from langchain_core.runnables import Runnable
9
- from langchain_openai import AzureChatOpenAI
9
+ from langchain_openai import AzureChatOpenAI, ChatOpenAI
10
10
 
11
- from janus.llm.model_callbacks import COST_PER_1K_TOKENS, azure_model_reroutes
11
+ from janus.llm.model_callbacks import (
12
+ COST_PER_1K_TOKENS,
13
+ azure_model_reroutes,
14
+ openai_model_reroutes,
15
+ )
12
16
  from janus.prompts.prompt import (
13
17
  ChatGptPromptEngine,
14
18
  ClaudePromptEngine,
@@ -46,6 +50,7 @@ except ImportError:
46
50
  ModelType = TypeVar(
47
51
  "ModelType",
48
52
  AzureChatOpenAI,
53
+ ChatOpenAI,
49
54
  HuggingFaceTextGenInference,
50
55
  Bedrock,
51
56
  BedrockChat,
@@ -127,7 +132,7 @@ bedrock_models = [
127
132
  all_models = [*azure_models, *bedrock_models]
128
133
 
129
134
  MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
130
- # "OpenAI": ChatOpenAI,
135
+ "OpenAI": ChatOpenAI,
131
136
  "HuggingFace": HuggingFaceTextGenInference,
132
137
  "Azure": AzureChatOpenAI,
133
138
  "Bedrock": Bedrock,
@@ -137,7 +142,7 @@ MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
137
142
 
138
143
 
139
144
  MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
140
- # **{m: ChatGptPromptEngine for m in openai_models},
145
+ **{m: ChatGptPromptEngine for m in openai_models},
141
146
  **{m: ChatGptPromptEngine for m in azure_models},
142
147
  **{m: ClaudePromptEngine for m in claude_models},
143
148
  **{m: Llama2PromptEngine for m in llama2_models},
@@ -148,7 +153,7 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
148
153
  }
149
154
 
150
155
  MODEL_ID_TO_LONG_ID = {
151
- # **{m: mr for m, mr in openai_model_reroutes.items()},
156
+ **{m: mr for m, mr in openai_model_reroutes.items()},
152
157
  **{m: mr for m, mr in azure_model_reroutes.items()},
153
158
  "bedrock-claude-v2": "anthropic.claude-v2",
154
159
  "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
@@ -181,7 +186,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
181
186
  MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
182
187
 
183
188
  MODEL_TYPES: dict[str, PromptEngine] = {
184
- # **{m: "OpenAI" for m in openai_models},
189
+ **{m: "OpenAI" for m in openai_models},
185
190
  **{m: "Azure" for m in azure_models},
186
191
  **{m: "BedrockChat" for m in bedrock_models},
187
192
  }
@@ -243,6 +248,7 @@ def load_model(model_id) -> JanusModel:
243
248
  token_limit = model_config["token_limit"]
244
249
  input_token_cost = model_config["model_cost"]["input"]
245
250
  output_token_cost = model_config["model_cost"]["output"]
251
+ input_token_proportion = model_config["input_token_proportion"]
246
252
 
247
253
  elif model_id in DEFAULT_MODELS:
248
254
  model_id = model_id
@@ -253,6 +259,7 @@ def load_model(model_id) -> JanusModel:
253
259
  token_limit = 0
254
260
  input_token_cost = 0.0
255
261
  output_token_cost = 0.0
262
+ input_token_proportion = 0.4
256
263
  if model_long_id in TOKEN_LIMITS:
257
264
  token_limit = TOKEN_LIMITS[model_long_id]
258
265
  if model_long_id in COST_PER_1K_TOKENS:
@@ -282,22 +289,22 @@ def load_model(model_id) -> JanusModel:
282
289
  elif model_type_name == "OpenAI":
283
290
  model_args.update(
284
291
  openai_api_key=str(os.getenv("OPENAI_API_KEY")),
285
- openai_organization=str(os.getenv("OPENAI_ORG_ID")),
286
292
  )
287
293
  # log.warning("Do NOT use this model in sensitive environments!")
288
294
  # log.warning("If you would like to cancel, please press Ctrl+C.")
289
295
  # log.warning("Waiting 10 seconds...")
290
296
  # Give enough time for the user to read the warnings and cancel
291
297
  # time.sleep(10)
292
- raise DeprecationWarning("OpenAI models are no longer supported.")
298
+ # raise DeprecationWarning("OpenAI models are no longer supported.")
293
299
 
294
300
  elif model_type_name == "Azure":
295
301
  model_args.update(
296
- {
297
- "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
298
- "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
299
- "api_version": os.getenv("OPENAI_API_VERSION", "2024-02-01"),
300
- }
302
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
303
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
304
+ api_version=os.getenv("OPENAI_API_VERSION", "2024-02-01"),
305
+ azure_deployment=model_id,
306
+ request_timeout=3600,
307
+ max_tokens=4096,
301
308
  )
302
309
 
303
310
  model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
@@ -305,15 +312,20 @@ def load_model(model_id) -> JanusModel:
305
312
 
306
313
  class JanusModel(model_type):
307
314
  model_id: str
315
+ # model_name is for LangChain compatibility
316
+ # It searches for `self.model_name` when counting tokens
317
+ model_name: str
308
318
  short_model_id: str
309
319
  model_type_name: str
310
320
  token_limit: int
321
+ input_token_proportion: float
311
322
  input_token_cost: float
312
323
  output_token_cost: float
313
324
  prompt_engine: type[PromptEngine]
314
325
 
315
326
  model_args.update(
316
327
  model_id=MODEL_ID_TO_LONG_ID[model_id],
328
+ model_name=model_id, # This is for LangChain compatibility
317
329
  short_model_id=model_id,
318
330
  )
319
331
 
@@ -322,6 +334,7 @@ def load_model(model_id) -> JanusModel:
322
334
  token_limit=token_limit,
323
335
  input_token_cost=input_token_cost,
324
336
  output_token_cost=output_token_cost,
337
+ input_token_proportion=input_token_proportion,
325
338
  prompt_engine=prompt_engine,
326
339
  **model_args,
327
340
  )
@@ -0,0 +1,10 @@
1
+ NAME OPA OPSA,OPSB
2
+ OPB OPSC,OPSC REMARK
3
+ NAME OPC OPSA,OPSB
4
+ OPD OPSA,OPSB REMARK2
5
+ B OPSA
6
+ OPD
7
+ B OPSB
8
+ NAME OPC OPSA,OPSB
9
+ OPC
10
+ OPC OPSA,OPSB
@@ -0,0 +1,6 @@
1
+ FUNC(a, b) ; apples
2
+ set apples=8
3
+ write a,!
4
+ write a,!
5
+ if abc=70 set f=1
6
+ quit 0
@@ -3,7 +3,7 @@ from pathlib import Path
3
3
 
4
4
  from typer.testing import CliRunner
5
5
 
6
- from janus.cli import app
6
+ from janus.cli.cli import app
7
7
  from janus.metrics.complexity_metrics import (
8
8
  TreeSitterMetric,
9
9
  cyclomatic_complexity,
@@ -0,0 +1,8 @@
1
+ Based on the following target written in the {language} programming language, how would you rate the code clarity of the target on a scale of integers from 1 to 10? Higher is better.
2
+
3
+ Think through your answer before selecting a rating with the following format:
4
+
5
+ Target: the target code
6
+ {format_instructions}
7
+
8
+ Target: {target}
@@ -0,0 +1,16 @@
1
+ Use the following rubric to evaluate the target written in the {language} programming language:
2
+
3
+ Rubric:
4
+ Does the comment address all capabilities of the relevant source code?
5
+
6
+ 10 - All essential functionality is documented.
7
+ 6-9 - Most essential functionality is documented.
8
+ 2-5 - Little essential functionality is documented.
9
+ 1 - No essential functionality is documented.
10
+
11
+ Think through your answer before selecting a rating with the following format:
12
+
13
+ Target: the target code
14
+ {format_instructions}
15
+
16
+ Target: {target}
@@ -0,0 +1,10 @@
1
+ Based on the following target and reference written in the {language} programming language, how would you rate the faithfulness of the target to the original reference on a scale of integers from 1 to 10? Higher is better.
2
+
3
+ Think through your answer before selecting a rating with the following format:
4
+
5
+ Target: the target code
6
+ Reference: the reference code that we are judging the target against
7
+ {format_instructions}
8
+
9
+ Target: {target}
10
+ Reference: {reference}
@@ -0,0 +1,16 @@
1
+ Use the following rubric to evaluate the target written in the {language} programming language:
2
+
3
+ Rubric:
4
+ Does the comment provide true information?
5
+
6
+ 10 - The comment provides only true information.
7
+ 6-9 - The comment provides mostly true information.
8
+ 2-5 - The comment provides mostly untrue information.
9
+ 1 - The comment is completely untrue.
10
+
11
+ Think through your answer before selecting a rating with the following format:
12
+
13
+ Target: the target code
14
+ {format_instructions}
15
+
16
+ Target: {target}
@@ -0,0 +1,8 @@
1
+ Based on the following target written in the {language} programming language, how would you rate the code quality of the target on a scale of integers from 1 to 10? Higher is better.
2
+
3
+ Think through your answer before selecting a rating with the following format:
4
+
5
+ Target: the target code
6
+ {format_instructions}
7
+
8
+ Target: {target}
@@ -0,0 +1,16 @@
1
+ Use the following rubric to evaluate the target written in the {language} programming language:
2
+
3
+ Rubric:
4
+ Is the comment clear to read?
5
+
6
+ 10 - The comment is well-written.
7
+ 6-9 - The comment has few problems.
8
+ 2-5 - The comment has many problems.
9
+ 1 - The comment is unreadable.
10
+
11
+ Think through your answer before selecting a rating with the following format:
12
+
13
+ Target: the target code
14
+ {format_instructions}
15
+
16
+ Target: {target}
@@ -0,0 +1,16 @@
1
+ Use the following rubric to evaluate the target written in the {language} programming language:
2
+
3
+ Rubric:
4
+ Is the comment useful?
5
+
6
+ 10 - The comment helps an expert programmer understand the code better.
7
+ 6-9 - The comment helps an average programmer understand the code better.
8
+ 2-5 - The comment documents only trivial functionality.
9
+ 1 - The comment is not useful at any level.
10
+
11
+ Think through your answer before selecting a rating with the following format:
12
+
13
+ Target: the target code
14
+ {format_instructions}
15
+
16
+ Target: {target}
@@ -1,9 +1,8 @@
1
1
  import re
2
2
 
3
- from langchain_core.exceptions import OutputParserException
4
3
  from langchain_core.messages import BaseMessage
5
4
 
6
- from janus.parsers.parser import JanusParser
5
+ from janus.parsers.parser import JanusParser, JanusParserException
7
6
  from janus.utils.logger import create_logger
8
7
 
9
8
  log = create_logger(__name__)
@@ -18,8 +17,9 @@ class CodeParser(JanusParser):
18
17
  pattern = rf"```[^\S\r\n]*(?:{self.language}[^\S\r\n]*)?\n?(.*?)\n*```"
19
18
  code = re.search(pattern, text, re.DOTALL)
20
19
  if code is None:
21
- raise OutputParserException(
22
- "Code not find code between triple square brackets"
20
+ raise JanusParserException(
21
+ text,
22
+ "Code not find code between triple square brackets",
23
23
  )
24
24
  return str(code.group(1))
25
25
 
@@ -8,7 +8,7 @@ from langchain_core.messages import BaseMessage
8
8
  from langchain_core.pydantic_v1 import BaseModel, Field
9
9
 
10
10
  from janus.language.block import CodeBlock
11
- from janus.parsers.parser import JanusParser
11
+ from janus.parsers.parser import JanusParser, JanusParserException
12
12
  from janus.utils.logger import create_logger
13
13
 
14
14
  log = create_logger(__name__)
@@ -86,7 +86,7 @@ class MultiDocumentationParser(JanusParser, PydanticOutputParser):
86
86
  return str(self.__class__.name)
87
87
 
88
88
 
89
- class MadlibsDocumentationParser(JanusParser):
89
+ class ClozeDocumentationParser(JanusParser):
90
90
  expected_keys: set[str]
91
91
 
92
92
  def __init__(self):
@@ -107,11 +107,12 @@ class MadlibsDocumentationParser(JanusParser):
107
107
  obj = parse_json_markdown(text)
108
108
  except json.JSONDecodeError as e:
109
109
  log.debug(f"Invalid JSON object. Output:\n{text}")
110
- raise OutputParserException(f"Got invalid JSON object. Error: {e}")
110
+ raise JanusParserException(text, f"Got invalid JSON object. Error: {e}")
111
111
 
112
112
  if not isinstance(obj, dict):
113
- raise OutputParserException(
114
- f"Got invalid return object. Expected a dictionary, but got {type(obj)}"
113
+ raise JanusParserException(
114
+ text,
115
+ f"Got invalid return object. Expected a dictionary, but got {type(obj)}",
115
116
  )
116
117
 
117
118
  seen_keys = set(obj.keys())
@@ -122,9 +123,10 @@ class MadlibsDocumentationParser(JanusParser):
122
123
  if invalid_keys:
123
124
  log.debug(f"Invalid keys: {invalid_keys}")
124
125
  log.debug(f"Missing keys: {missing_keys}")
125
- raise OutputParserException(
126
+ raise JanusParserException(
127
+ text,
126
128
  f"Got invalid return object. Missing the following expected "
127
- f"keys: {missing_keys}"
129
+ f"keys: {missing_keys}",
128
130
  )
129
131
 
130
132
  for key in invalid_keys:
@@ -132,9 +134,10 @@ class MadlibsDocumentationParser(JanusParser):
132
134
 
133
135
  for value in obj.values():
134
136
  if not isinstance(value, str):
135
- raise OutputParserException(
137
+ raise JanusParserException(
138
+ text,
136
139
  f"Got invalid return object. Expected all string values,"
137
- f' but got type "{type(value)}"'
140
+ f' but got type "{type(value)}"',
138
141
  )
139
142
 
140
143
  return json.dumps(obj)
@@ -0,0 +1,134 @@
1
+ import json
2
+ import random
3
+ import uuid
4
+ from typing import List
5
+
6
+ from langchain.output_parsers import PydanticOutputParser
7
+ from langchain_core.exceptions import OutputParserException
8
+ from langchain_core.messages import BaseMessage
9
+ from langchain_core.pydantic_v1 import BaseModel, Field, validator
10
+
11
+ from janus.language.block import CodeBlock
12
+ from janus.parsers.parser import JanusParser
13
+ from janus.utils.logger import create_logger
14
+
15
+ log = create_logger(__name__)
16
+ RNG = random.Random()
17
+
18
+
19
class Criteria(BaseModel):
    """A single pass/fail assessment of a requirement against one criterion."""

    reasoning: str = Field(description="A short explanation for the given assessment")
    # BUG FIX: the description string was previously passed positionally, which
    # Pydantic interprets as the field's *default value* -- `score` was silently
    # optional and undocumented. Passing it as `description=` makes the field
    # required and documents it in the generated format instructions.
    score: str = Field(description="A simple `pass` or `fail`")

    @validator("score")
    def score_is_valid(cls, v: str):
        """Normalize the score and reject anything other than 'pass'/'fail'."""
        v = v.lower().strip()
        if v not in {"pass", "fail"}:
            raise OutputParserException("Score must be either 'pass' or 'fail'")
        return v
29
+
30
+
31
class Requirement(BaseModel):
    """Evaluation of one requirement against nine criteria (C1-C9).

    The criterion fields each hold a Criteria (reasoning + pass/fail score);
    presumably C1-C9 map to the INCOSE well-formed-requirement rules -- the
    exact rubric lives in the prompt template, not here.
    """

    requirement_id: str = Field(description="The 8-character comment ID")
    requirement: str = Field(description="The original requirement being evaluated")
    # One assessment object per criterion; field names must match the keys the
    # LLM is instructed to emit.
    C1: Criteria
    C2: Criteria
    C3: Criteria
    C4: Criteria
    C5: Criteria
    C6: Criteria
    C7: Criteria
    C8: Criteria
    C9: Criteria
+
44
+
45
class RequirementList(BaseModel):
    """Top-level response schema: a bare JSON array of Requirement objects.

    Using `__root__` makes the expected LLM output a list rather than an
    object wrapping a list.
    """

    __root__: List[Requirement] = Field(
        description=(
            "A list of requirement evaluations. Each element should include"
            " the requirement's 8-character ID in the `requirement_id` field,"
            " the original requirement in the 'requirement' field, "
            " and nine score objects corresponding to each criterion."
        )
    )
54
+
55
+
56
class IncoseParser(JanusParser, PydanticOutputParser):
    """Parser for LLM-based INCOSE-style requirement evaluations.

    `parse_input` tags each requirement in the input with a deterministic
    8-character ID; `parse` matches the LLM's evaluations back to those IDs
    and emits a JSON object keyed by requirement ID.
    """

    # Maps the generated 8-character ID to the original requirement text.
    requirements: dict[str, str]

    def __init__(self):
        PydanticOutputParser.__init__(
            self,
            pydantic_object=RequirementList,
            requirements={},
        )

    def parse_input(self, block: CodeBlock) -> str:
        """Extract requirements from the input JSON and tag each with a unique ID.

        The RNG is seeded with the input text so IDs are deterministic per input.
        """
        # TODO: Perform comment stripping/placeholding here rather than in script
        text = super().parse_input(block)
        RNG.seed(text)

        obj = json.loads(text)

        # For some reason requirements objects are in a double list?
        reqs = obj["requirements"]

        # Generate a unique ID for each requirement (ensure they are unique)
        req_ids = set()
        while len(req_ids) < len(reqs):
            req_ids.add(str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8])

        self.requirements = dict(zip(req_ids, reqs))
        reqs_str = "\n\n".join(
            f"Requirement {rid} : {req}" for rid, req in self.requirements.items()
        )
        obj["requirements"] = reqs_str
        return json.dumps(obj)

    def parse(self, text: str | BaseMessage) -> str:
        """Parse the LLM response into a JSON object keyed by requirement ID.

        Raises OutputParserException when the response contains IDs that were
        never issued. Missing IDs are logged but tolerated -- the output simply
        contains fewer evaluations.
        """
        if isinstance(text, BaseMessage):
            text = str(text.content)

        # Strip everything outside the JSON array
        begin, end = text.find("["), text.rfind("]")
        text = text[begin : end + 1]

        try:
            out: RequirementList = super().parse(text)
        except json.JSONDecodeError as e:
            # Defensive: PydanticOutputParser normally wraps JSON errors in
            # OutputParserException itself, so this is unlikely to fire.
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise OutputParserException(f"Got invalid JSON object. Error: {e}")

        evals: dict[str, dict] = {c.requirement_id: c.dict() for c in out.__root__}

        seen_keys = set(evals.keys())
        expected_keys = set(self.requirements.keys())
        missing_keys = expected_keys.difference(seen_keys)
        invalid_keys = seen_keys.difference(expected_keys)
        if missing_keys:
            log.debug(f"Missing keys: {missing_keys}")
        if invalid_keys:
            # FIX: previously this branch logged the missing keys twice and
            # raised an error message that mentioned only missing keys; the
            # post-raise `del` cleanup loop was unreachable.
            log.debug(f"Invalid keys: {invalid_keys}")
            raise OutputParserException(
                f"Got invalid return object. Unexpected keys: {invalid_keys}; "
                f"missing expected keys: {missing_keys}"
            )

        # Re-attach the original requirement text and drop the internal ID field.
        for rid in evals.keys():
            evals[rid]["requirement"] = self.requirements[rid]
            evals[rid].pop("requirement_id")

        return json.dumps(evals)

    def parse_combined_output(self, text: str) -> str:
        """Merge one JSON object per line into a single JSON object string."""
        if not text.strip():
            return json.dumps({})
        objs = [json.loads(line.strip()) for line in text.split("\n") if line.strip()]
        output_obj = {}
        for obj in objs:
            output_obj.update(obj)
        return json.dumps(output_obj)
@@ -0,0 +1,112 @@
1
+ import json
2
+ import re
3
+ from typing import Any
4
+
5
+ from langchain.output_parsers import PydanticOutputParser
6
+ from langchain_core.exceptions import OutputParserException
7
+ from langchain_core.messages import BaseMessage
8
+ from langchain_core.pydantic_v1 import BaseModel, Field, conint
9
+
10
+ from janus.language.block import CodeBlock
11
+ from janus.parsers.parser import JanusParser
12
+ from janus.utils.logger import create_logger
13
+
14
+ log = create_logger(__name__)
15
+
16
+
17
class Criteria(BaseModel):
    """A scored assessment of one inline-comment quality metric."""

    reasoning: str = Field(description="A short explanation for the given score")
    # Constrained to an integer between 1 and 4
    score: conint(ge=1, le=4) = Field(  # type: ignore
        description="An integer score between 1 and 4 (inclusive), 4 being the best"
    )
23
+
24
+
25
class Comment(BaseModel):
    """Evaluation of a single inline/block comment across four quality metrics."""

    comment_id: str = Field(description="The 8-character comment ID")
    completeness: Criteria = Field(description="The completeness of the comment")
    hallucination: Criteria = Field(description="The factualness of the comment")
    readability: Criteria = Field(description="The readability of the comment")
    usefulness: Criteria = Field(description="The usefulness of the comment")
31
+
32
+
33
class CommentList(BaseModel):
    """Top-level response schema: a bare JSON array of Comment evaluations.

    Using `__root__` makes the expected LLM output a list rather than an
    object wrapping a list.
    """

    __root__: list[Comment] = Field(
        description=(
            "A list of inline comment evaluations. Each element should include"
            " the comment's 8-character ID in the `comment_id` field, and four"
            " score objects corresponding to each metric (`completeness`,"
            " `hallucination`, `readability`, and `usefulness`)."
        )
    )
42
+
43
+
44
class InlineCommentParser(JanusParser, PydanticOutputParser):
    """Parser for LLM-based evaluations of tagged inline/block comments.

    `parse_input` records the `<BLOCK_COMMENT id>` / `<INLINE_COMMENT id>`
    tags found in the input; `parse` matches the LLM's evaluations back to
    those IDs and emits a JSON object keyed by comment ID.
    """

    # Maps the 8-character comment ID to the comment's text.
    comments: dict[str, str]

    def __init__(self):
        PydanticOutputParser.__init__(
            self,
            pydantic_object=CommentList,
            # FIX: was `comments=[]`, which only validated because Pydantic
            # coerces an empty list via dict(); use an empty dict to match the
            # declared dict[str, str] type.
            comments={},
        )

    def parse_input(self, block: CodeBlock) -> str:
        """Collect the tagged comments from the input; return the text unchanged."""
        # TODO: Perform comment stripping/placeholding here rather than in script
        text = super().parse_input(block)
        self.comments = dict(
            re.findall(
                r"<(?:BLOCK|INLINE)_COMMENT (\w{8})> (.*)$",
                text,
                flags=re.MULTILINE,
            )
        )
        return text

    def parse(self, text: str | BaseMessage) -> str:
        """Parse the LLM response into a JSON object keyed by comment ID.

        Raises OutputParserException when the response contains IDs that were
        never issued. Missing IDs are logged but tolerated -- the output simply
        contains fewer evaluations.
        """
        if isinstance(text, BaseMessage):
            text = str(text.content)

        # Strip everything outside the JSON array
        begin, end = text.find("["), text.rfind("]")
        text = text[begin : end + 1]

        try:
            out: CommentList = super().parse(text)
        except json.JSONDecodeError as e:
            # Defensive: PydanticOutputParser normally wraps JSON errors in
            # OutputParserException itself, so this is unlikely to fire.
            log.debug(f"Invalid JSON object. Output:\n{text}")
            raise OutputParserException(f"Got invalid JSON object. Error: {e}")

        evals: dict[str, Any] = {c.comment_id: c.dict() for c in out.__root__}

        seen_keys = set(evals.keys())
        expected_keys = set(self.comments.keys())
        missing_keys = expected_keys.difference(seen_keys)
        invalid_keys = seen_keys.difference(expected_keys)
        if missing_keys:
            log.debug(f"Missing keys: {missing_keys}")
        if invalid_keys:
            # FIX: previously this branch logged the missing keys twice and
            # raised an error message that mentioned only missing keys; the
            # post-raise `del` cleanup loop was unreachable.
            log.debug(f"Invalid keys: {invalid_keys}")
            raise OutputParserException(
                f"Got invalid return object. Unexpected keys: {invalid_keys}; "
                f"missing expected keys: {missing_keys}"
            )

        # Re-attach the original comment text and drop the internal ID field.
        for cid in evals.keys():
            evals[cid]["comment"] = self.comments[cid]
            evals[cid].pop("comment_id")

        return json.dumps(evals)

    def parse_combined_output(self, text: str) -> str:
        """Merge one JSON object per line into a single JSON object string."""
        if not text.strip():
            return json.dumps({})
        objs = [json.loads(line.strip()) for line in text.split("\n") if line.strip()]
        output_obj = {}
        for obj in objs:
            output_obj.update(obj)
        return json.dumps(output_obj)
janus/parsers/parser.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from langchain.schema.output_parser import BaseOutputParser
2
+ from langchain_core.exceptions import OutputParserException
2
3
  from langchain_core.messages import BaseMessage
3
4
  from langchain_core.output_parsers import StrOutputParser
4
5
 
@@ -49,3 +50,9 @@ class GenericParser(JanusParser, StrOutputParser):
49
50
 
50
51
  def get_format_instructions(self) -> str:
51
52
  return "Output should be a string"
53
+
54
+
55
+ class JanusParserException(OutputParserException):
56
+ def __init__(self, unparsed_output, *args, **kwargs):
57
+ self.unparsed_output = unparsed_output
58
+ super().__init__(*args, **kwargs)