janus-llm 4.2.0__py3-none-any.whl → 4.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli.py +150 -5
- janus/converter/converter.py +1 -0
- janus/converter/evaluate.py +230 -4
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +9 -4
- janus/language/splitter.py +31 -23
- janus/language/treesitter/treesitter.py +9 -1
- janus/llm/models_info.py +17 -12
- janus/parsers/eval_parsers/incose_parser.py +134 -0
- janus/parsers/eval_parsers/inline_comment_parser.py +112 -0
- janus/parsers/partition_parser.py +41 -9
- janus/refiners/refiner.py +30 -0
- janus/utils/enums.py +14 -0
- {janus_llm-4.2.0.dist-info → janus_llm-4.3.1.dist-info}/METADATA +1 -1
- {janus_llm-4.2.0.dist-info → janus_llm-4.3.1.dist-info}/RECORD +19 -17
- {janus_llm-4.2.0.dist-info → janus_llm-4.3.1.dist-info}/LICENSE +0 -0
- {janus_llm-4.2.0.dist-info → janus_llm-4.3.1.dist-info}/WHEEL +0 -0
- {janus_llm-4.2.0.dist-info → janus_llm-4.3.1.dist-info}/entry_points.txt +0 -0
janus/__init__.py
CHANGED
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
 from janus.converter.translate import Translator
 from janus.metrics import *  # noqa: F403
 
-__version__ = "4.2.0"
+__version__ = "4.3.1"
 
 # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
 warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
janus/cli.py
CHANGED
@@ -19,6 +19,7 @@ from janus.converter.aggregator import Aggregator
 from janus.converter.converter import Converter
 from janus.converter.diagram import DiagramGenerator
 from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
+from janus.converter.evaluate import InlineCommentEvaluator, RequirementEvaluator
 from janus.converter.partition import Partitioner
 from janus.converter.requirements import RequirementsDocumenter
 from janus.converter.translate import Translator
@@ -127,7 +128,7 @@ embedding = typer.Typer(
 
 def version_callback(value: bool) -> None:
     if value:
-        from
+        from . import __version__ as version
 
         print(f"Janus CLI [blue]v{version}[/blue]")
         raise typer.Exit()
@@ -655,6 +656,16 @@ def partition(
            click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
        ),
    ] = "file",
+   refiner_types: Annotated[
+       list[str],
+       typer.Option(
+           "-r",
+           "--refiner",
+           help="List of refiner types to use. Add -r for each refiner to use in\
+               refinement chain",
+           click_type=click.Choice(list(REFINERS.keys())),
+       ),
+   ] = ["JanusRefiner"],
    max_tokens: Annotated[
        int,
        typer.Option(
@@ -673,6 +684,7 @@ def partition(
        ),
    ] = 8192,
 ):
+    refiner_types = [REFINERS[r] for r in refiner_types]
     model_arguments = dict(temperature=temperature)
     kwargs = dict(
         model=llm_name,
@@ -681,6 +693,7 @@ def partition(
         max_prompts=max_prompts,
         max_tokens=max_tokens,
         splitter_type=splitter_type,
+        refiner_types=refiner_types,
         partition_token_limit=partition_token_limit,
     )
     partitioner = Partitioner(**kwargs)
@@ -815,6 +828,139 @@ def diagram(
     diagram_generator.translate(input_dir, output_dir, overwrite, collection)
 
 
+@app.command(
+    help="LLM self evaluation",
+    no_args_is_help=True,
+)
+def llm_self_eval(
+    input_dir: Annotated[
+        Path,
+        typer.Option(
+            "--input",
+            "-i",
+            help="The directory containing the source code to be evaluated. "
+            "The files should all be in one flat directory.",
+        ),
+    ],
+    language: Annotated[
+        str,
+        typer.Option(
+            "--language",
+            "-l",
+            help="The language of the source code.",
+            click_type=click.Choice(sorted(LANGUAGES)),
+        ),
+    ],
+    output_dir: Annotated[
+        Path,
+        typer.Option(
+            "--output-dir", "-o", help="The directory to store the evaluations in."
+        ),
+    ],
+    llm_name: Annotated[
+        str,
+        typer.Option(
+            "--llm",
+            "-L",
+            help="The custom name of the model set with 'janus llm add'.",
+        ),
+    ] = "gpt-4o",
+    evaluation_type: Annotated[
+        str,
+        typer.Option(
+            "--evaluation-type",
+            "-e",
+            help="Type of output to evaluate.",
+            click_type=click.Choice(["incose", "comments"]),
+        ),
+    ] = "incose",
+    max_prompts: Annotated[
+        int,
+        typer.Option(
+            "--max-prompts",
+            "-m",
+            help="The maximum number of times to prompt a model on one functional block "
+            "before exiting the application. This is to prevent wasting too much money.",
+        ),
+    ] = 10,
+    overwrite: Annotated[
+        bool,
+        typer.Option(
+            "--overwrite/--preserve",
+            help="Whether to overwrite existing files in the output directory",
+        ),
+    ] = False,
+    temperature: Annotated[
+        float,
+        typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
+    ] = 0.7,
+    collection: Annotated[
+        str,
+        typer.Option(
+            "--collection",
+            "-c",
+            help="If set, will put the translated result into a Chroma DB "
+            "collection with the name provided.",
+        ),
+    ] = None,
+    splitter_type: Annotated[
+        str,
+        typer.Option(
+            "-S",
+            "--splitter",
+            help="Name of custom splitter to use",
+            click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
+        ),
+    ] = "file",
+    refiner_types: Annotated[
+        list[str],
+        typer.Option(
+            "-r",
+            "--refiner",
+            help="List of refiner types to use. Add -r for each refiner to use in\
+                refinement chain",
+            click_type=click.Choice(list(REFINERS.keys())),
+        ),
+    ] = ["JanusRefiner"],
+    eval_items_per_request: Annotated[
+        int,
+        typer.Option(
+            "--eval-items-per-request",
+            "-rc",
+            help="The maximum number of evaluation items per request",
+        ),
+    ] = None,
+    max_tokens: Annotated[
+        int,
+        typer.Option(
+            "--max-tokens",
+            "-M",
+            help="The maximum number of tokens the model will take in. "
+            "If unspecificed, model's default max will be used.",
+        ),
+    ] = None,
+):
+    model_arguments = dict(temperature=temperature)
+    refiner_types = [REFINERS[r] for r in refiner_types]
+    kwargs = dict(
+        eval_items_per_request=eval_items_per_request,
+        model=llm_name,
+        model_arguments=model_arguments,
+        source_language=language,
+        max_prompts=max_prompts,
+        max_tokens=max_tokens,
+        splitter_type=splitter_type,
+        refiner_types=refiner_types,
+    )
+    # Setting parser type here
+    if evaluation_type == "incose":
+        evaluator = RequirementEvaluator(**kwargs)
+    elif evaluation_type == "comments":
+        evaluator = InlineCommentEvaluator(**kwargs)
+
+    evaluator.translate(input_dir, output_dir, overwrite, collection)
+
+
 @db.command("init", help="Connect to or create a database.")
 def db_init(
     path: Annotated[
@@ -1116,13 +1262,12 @@ def llm_add(
        show_choices=False,
    )
    params = dict(
-
-       model_name=MODEL_ID_TO_LONG_ID[model_id],
+       model_name=model_name,
        temperature=0.7,
        n=1,
    )
-   max_tokens = TOKEN_LIMITS[
-   model_cost = COST_PER_1K_TOKENS[
+   max_tokens = TOKEN_LIMITS[model_name]
+   model_cost = COST_PER_1K_TOKENS[model_name]
    cfg = {
        "model_type": model_type,
        "model_id": model_id,
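
Note: both partition and the new llm_self_eval command accept repeated -r/--refiner flags; each name is validated against the REFINERS registry and mapped to its class before the converter is built. A minimal sketch of that lookup pattern (the registry contents below are illustrative stand-ins, not the actual REFINERS mapping):

    # Stand-in registry; the real classes live in janus.refiners.refiner
    class JanusRefiner: ...

    REFINERS = {"JanusRefiner": JanusRefiner}

    refiner_names = ["JanusRefiner"]  # as collected from repeated -r flags
    refiner_types = [REFINERS[name] for name in refiner_names]  # classes, not instances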
janus/converter/converter.py
CHANGED
@@ -464,6 +464,7 @@ class Converter:
         for in_path, out_path in in_out_pairs:
             # Translate the file, skip it if there's a rate limit error
             try:
+                log.info(f"Processing {in_path.relative_to(input_directory)}")
                 out_block = self.translate_file(in_path)
                 total_cost += out_block.total_cost
             except RateLimitError:
janus/converter/evaluate.py
CHANGED
@@ -1,15 +1,241 @@
+import json
+import re
+from copy import deepcopy
+
+from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel
+
 from janus.converter.converter import Converter
+from janus.language.block import TranslatedCodeBlock
 from janus.language.combine import JsonCombiner
-from janus.parsers.
+from janus.parsers.eval_parsers.incose_parser import IncoseParser
+from janus.parsers.eval_parsers.inline_comment_parser import InlineCommentParser
 from janus.utils.logger import create_logger
 
 log = create_logger(__name__)
 
 
 class Evaluator(Converter):
-
+    """Evaluator
+
+    A class that performs an LLM self evaluation"
+    "on an input target, with an associated prompt.
+
+    Current valid evaluation types:
+    ['incose', 'comments']
+
+    """
+
+    def __init__(self, **kwargs) -> None:
+        """Initialize the Evaluator class
+
+        Arguments:
+            model: The LLM to use for translation. If an OpenAI model, the
+                `OPENAI_API_KEY` environment variable must be set and the
+                `OPENAI_ORG_ID` environment variable should be set if needed.
+            model_arguments: Additional arguments to pass to the LLM constructor.
+            max_prompts: The maximum number of prompts to try before giving up.
+        """
+        super().__init__(**kwargs)
+        self._combiner = JsonCombiner()
+        self._load_parameters()
+
+
+class RequirementEvaluator(Evaluator):
+    """INCOSE Requirement Evaluator
+
+    A class that performs an LLM self evaluation on an input target,
+    with an associated prompt.
+
+    The evaluation prompts are for Incose Evaluations
+
+    """
+
+    def __init__(self, eval_items_per_request: int | None = None, **kwargs) -> None:
+        """Initialize the Evaluator class
+
+        Arguments:
+            model: The LLM to use for translation. If an OpenAI model, the
+                `OPENAI_API_KEY` environment variable must be set and the
+                `OPENAI_ORG_ID` environment variable should be set if needed.
+            model_arguments: Additional arguments to pass to the LLM constructor.
+            max_prompts: The maximum number of prompts to try before giving up.
+        """
+        super().__init__(**kwargs)
+        self.eval_items_per_request = eval_items_per_request
+        self._parser = IncoseParser()
+        self.set_prompt("eval_prompts/incose")
+
+    def _input_runnable(self) -> Runnable:
+        def _get_code(json_text: str) -> str:
+            return json.loads(json_text)["code"]
+
+        def _get_reqs(json_text: str) -> str:
+            return json.dumps(json.loads(json_text)["requirements"])
+
+        return RunnableLambda(self._parser.parse_input) | RunnableParallel(
+            SOURCE_CODE=_get_code,
+            REQUIREMENTS=_get_reqs,
+            context=self._retriever,
+        )
+
+    def _add_translation(self, block: TranslatedCodeBlock):
+        if block.translated:
+            return
+
+        if block.original.text is None:
+            block.translated = True
+            return
+
+        if self.eval_items_per_request is None:
+            return super()._add_translation(block)
+
+        input_obj = json.loads(block.original.text)
+        requirements = input_obj.get("requirements", [])
+
+        if not requirements:
+            log.debug(f"[{block.name}] Skipping empty block")
+            block.translated = True
+            block.text = None
+            block.complete = True
+            return
+
+        # For some reason requirements objects are in nested lists?
+        while isinstance(requirements[0], list):
+            requirements = [r for lst in requirements for r in lst]
+
+        if len(requirements) <= self.eval_items_per_request:
+            input_obj["requirements"] = requirements
+            block.original.text = json.dumps(input_obj)
+            return super()._add_translation(block)
+
+        block.processing_time = 0
+        block.cost = 0
+        block.retries = 0
+        obj = {}
+        for i in range(0, len(requirements), self.eval_items_per_request):
+            # Build a new TranslatedBlock using the new working text
+            working_requirements = requirements[i : i + self.eval_items_per_request]
+            working_copy = deepcopy(block.original)
+            working_obj = json.loads(working_copy.text)  # type: ignore
+            working_obj["requirements"] = working_requirements
+            working_copy.text = json.dumps(working_obj)
+            working_block = TranslatedCodeBlock(working_copy, self._target_language)
+
+            # Run the LLM on the working text
+            super()._add_translation(working_block)
+
+            # Update metadata to include for all runs
+            block.retries += working_block.retries
+            block.cost += working_block.cost
+            block.processing_time += working_block.processing_time
+
+            # Update the output text to merge this section's output in
+            obj.update(json.loads(working_block.text))
+
+        block.text = json.dumps(obj)
+        block.tokens = self._llm.get_num_tokens(block.text)
+        block.translated = True
+
+        log.debug(
+            f"[{block.name}] Output code:\n{json.dumps(json.loads(block.text), indent=2)}"
+        )
+
+
+class InlineCommentEvaluator(Evaluator):
+    """Inline Comment Evaluator
+
+    A class that performs an LLM self evaluation on inline comments,
+    with an associated prompt.
+    """
+
+    def __init__(self, eval_items_per_request: int | None = None, **kwargs) -> None:
+        """Initialize the Evaluator class
+
+        Arguments:
+            model: The LLM to use for translation. If an OpenAI model, the
+                `OPENAI_API_KEY` environment variable must be set and the
+                `OPENAI_ORG_ID` environment variable should be set if needed.
+            model_arguments: Additional arguments to pass to the LLM constructor.
+            max_prompts: The maximum number of prompts to try before giving up.
+        """
         super().__init__(**kwargs)
-        self.set_prompt("evaluate")
         self._combiner = JsonCombiner()
-        self._parser = EvaluationParser()
         self._load_parameters()
+        self._parser = InlineCommentParser()
+        self.set_prompt("eval_prompts/inline_comments")
+        self.eval_items_per_request = eval_items_per_request
+
+    def _add_translation(self, block: TranslatedCodeBlock):
+        if block.translated:
+            return
+
+        if block.original.text is None:
+            block.translated = True
+            return
+
+        if self.eval_items_per_request is None:
+            return super()._add_translation(block)
+
+        comment_pattern = r"<(?:INLINE|BLOCK)_COMMENT \w{8}>.*$"
+        comments = list(
+            re.finditer(comment_pattern, block.original.text, flags=re.MULTILINE)
+        )
+
+        if not comments:
+            log.info(f"[{block.name}] Skipping commentless block")
+            block.translated = True
+            block.text = None
+            block.complete = True
+            return
+
+        if len(comments) <= self.eval_items_per_request:
+            return super()._add_translation(block)
+
+        comment_group_indices = list(range(0, len(comments), self.eval_items_per_request))
+        log.debug(
+            f"[{block.name}] Block contains more than {self.eval_items_per_request}"
+            f" comments, splitting {len(comments)} comments into"
+            f" {len(comment_group_indices)} groups"
+        )
+
+        block.processing_time = 0
+        block.cost = 0
+        block.retries = 0
+        obj = {}
+        for i in range(0, len(comments), self.eval_items_per_request):
+            # Split the text into the section containing comments of interest,
+            # all the text prior to those comments, and all the text after them
+            working_comments = comments[i : i + self.eval_items_per_request]
+            start_idx = working_comments[0].start()
+            end_idx = working_comments[-1].end()
+            prefix = block.original.text[:start_idx]
+            keeper = block.original.text[start_idx:end_idx]
+            suffix = block.original.text[end_idx:]
+
+            # Strip all comment placeholders outside of the section of interest
+            prefix = re.sub(comment_pattern, "", prefix, flags=re.MULTILINE)
+            suffix = re.sub(comment_pattern, "", suffix, flags=re.MULTILINE)
+
+            # Build a new TranslatedBlock using the new working text
+            working_copy = deepcopy(block.original)
+            working_copy.text = prefix + keeper + suffix
+            working_block = TranslatedCodeBlock(working_copy, self._target_language)
+
+            # Run the LLM on the working text
+            super()._add_translation(working_block)
+
+            # Update metadata to include for all runs
+            block.retries += working_block.retries
+            block.cost += working_block.cost
+            block.processing_time += working_block.processing_time
+
+            # Update the output text to merge this section's output in
+            obj.update(json.loads(working_block.text))
+
+        block.text = json.dumps(obj)
+        block.tokens = self._llm.get_num_tokens(block.text)
+        block.translated = True
+
+        log.debug(
+            f"[{block.name}] Output code:\n{json.dumps(json.loads(block.text), indent=2)}"
+        )
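
Note: both evaluators batch their work the same way: when a block holds more items than eval_items_per_request, it is split into fixed-size groups, each group is run through the LLM separately, and the per-group JSON outputs are merged (with cost, retries, and processing time accumulated on the parent block). A standalone sketch of that merge loop, with process() standing in for the LLM call:

    import json

    def batched_evaluate(items, items_per_request, process):
        merged = {}
        for i in range(0, len(items), items_per_request):
            group = items[i : i + items_per_request]
            # Each group is processed independently; results merge into one object
            merged.update(json.loads(process(group)))
        return json.dumps(merged)

    # Trivial stand-in for the model call:
    result = batched_evaluate(
        ["req-a", "req-b", "req-c"],
        2,
        lambda group: json.dumps({r: "pass" for r in group}),
    )
    print(result)  # {"req-a": "pass", "req-b": "pass", "req-c": "pass"}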
janus/language/alc/_tests/test_alc.py
CHANGED
@@ -20,7 +20,7 @@ class TestAlcSplitter(unittest.TestCase):
     def test_split(self):
         """Test the split method."""
         tree_root = self.splitter.split(self.test_file)
-        self.assertAlmostEqual(tree_root.n_descendents,
+        self.assertAlmostEqual(tree_root.n_descendents, 16, delta=2)
         self.assertLessEqual(tree_root.max_tokens, self.splitter.max_tokens)
         self.assertFalse(tree_root.complete)
         self.combiner.combine_children(tree_root)
janus/language/alc/alc.py
CHANGED
@@ -79,10 +79,15 @@ class AlcSplitter(TreeSitterSplitter):
             if len(sects) > 1:
                 block.children = []
                 for sect in sects:
-
-
-
-
+                    node_type = sect[0].node_type
+                    if node_type in sect_types:
+                        if len(sect) == 1:
+                            # Don't make a node its own child
+                            sect_node = sect[0]
+                        else:
+                            sect_node = self.merge_nodes(sect)
+                            sect_node.children = sect
+                            sect_node.node_type = NodeType(str(node_type)[:5])
                         block.children.append(sect_node)
                     else:
                         block.children.extend(sect)
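
Note: the rewritten loop in AlcSplitter fixes a degenerate case: a section consisting of a single node is now used directly instead of being merged into a parent that contains itself. A sketch of the shape of that guard (merge stands in for self.merge_nodes; the node_type bookkeeping is omitted):

    def wrap_section(sect, merge):
        if len(sect) == 1:
            return sect[0]  # don't make a node its own child
        parent = merge(sect)  # merged node spanning the whole run
        parent.children = sect
        return parent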
janus/language/splitter.py
CHANGED
@@ -275,42 +275,50 @@ class Splitter(FileManager):
 
         groups = [[n] for n in nodes]
         while len(groups) > 1 and min(adj_sums) <= self.max_tokens and any(merge_allowed):
-            # Get the
-            #
+            # Get the index of the node that would result in the smallest
+            # merged snippet when merged with the node that follows it.
+            # Ignore protected nodes.
             mergeable_indices = compress(range(len(adj_sums)), merge_allowed)
-
-
+            C = int(min(mergeable_indices, key=adj_sums.__getitem__))
+
+            # C: Central index
+            # L: Index to the left
+            # R: Index to the right (to be merged in to C)
+            # N: Next index (to the right of R, the "new R")
+            L, R, N = C - 1, C + 1, C + 2
 
             # Recalculate the length. We can't simply use the adj_sum, because
             # it is an underestimate due to the adjoining suffix/prefix.
-            central_node = groups[
-            merged_text = "".join([text_chunks[
+            central_node = groups[C][-1]
+            merged_text = "".join([text_chunks[C], central_node.suffix, text_chunks[R]])
             merged_text_length = self._count_tokens(merged_text)
 
             # If the true length of the merged pair is too long, don't merge them
             # Instead, correct the estimate, since shorter pairs may yet exist
             if merged_text_length > self.max_tokens:
-                adj_sums[
+                adj_sums[C] = merged_text_length
                 continue
 
             # Update adjacent sum estimates
-            if
-                adj_sums[
-            if
-                adj_sums[
-
-            if i0 > 0 and i1 < len(merge_allowed) - 1:
-                if not (merge_allowed[i0 - 1] and merge_allowed[i1 + 1]):
-                    merge_allowed[i0 - 1] = merge_allowed[i1 + 1] = False
+            if L >= 0:
+                adj_sums[L] = lengths[L] + merged_text_length
+            if N < len(adj_sums):
+                adj_sums[R] = lengths[N] + merged_text_length
 
             # The potential merge length for this pair is removed
-            adj_sums.pop(
-
+            adj_sums.pop(C)
+
+            # The merged-in node is removed from the protected list
+            # The merge_allowed list need not be updated - if the node now to
+            # its right is protected, the merge_allowed element corresponding
+            # to the merged neighbor will have been True, and now corresponds
+            # to the merged node.
+            merge_allowed.pop(C)
 
             # Merge the pair of node groups
-            groups[
-            text_chunks[
-            lengths[
+            groups[C:N] = [groups[C] + groups[R]]
+            text_chunks[C:N] = [merged_text]
+            lengths[C:N] = [merged_text_length]
 
         return groups
 
@@ -403,13 +411,13 @@ class Splitter(FileManager):
         self._split_into_lines(node)
 
     def _split_into_lines(self, node: CodeBlock):
-        split_text = re.split(r"(\n+)", node.text)
+        split_text = list(re.split(r"(\n+)", node.text))
 
         # If the string didn't start/end with newlines, make sure to include
         # empty strings for the prefix/suffixes
-        if
+        if not re.match(r"^\n+$", split_text[0]):
             split_text = [""] + split_text
-        if split_text[-1]
+        if not re.match(r"^\n+$", split_text[-1]):
             split_text.append("")
         betweens = split_text[::2]
         lines = split_text[1::2]
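
Note: the splitter change names the merge-loop indices (C/L/R/N) and keeps adj_sums, lengths, text_chunks, and merge_allowed aligned as pairs collapse. The underlying strategy is a greedy adjacent merge: repeatedly fuse the neighboring pair with the smallest combined size until no pair fits the token budget. A simplified sketch (plain len() instead of token counting, no protected nodes, no suffix/prefix correction):

    def greedy_merge(chunks: list[str], max_tokens: int) -> list[str]:
        chunks = list(chunks)
        while len(chunks) > 1:
            # Index of the adjacent pair whose merged size is smallest
            sums = [len(chunks[i]) + len(chunks[i + 1]) for i in range(len(chunks) - 1)]
            c = min(range(len(sums)), key=sums.__getitem__)
            if sums[c] > max_tokens:
                break  # even the smallest pair exceeds the limit
            chunks[c : c + 2] = [chunks[c] + chunks[c + 1]]
        return chunks

    print(greedy_merge(["aa", "b", "cccc", "dd"], 4))  # ['aab', 'cccc', 'dd']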
janus/language/treesitter/treesitter.py
CHANGED
@@ -154,7 +154,15 @@ class TreeSitterSplitter(Splitter):
         The pointer to the language.
         """
         lib = cdll.LoadLibrary(os.fspath(so_file))
-        language_function = getattr(lib, f"tree_sitter_{self.language}")
+        # Added this try-except block to handle the case where the language is not
+        # supported in lowercase by the creator of the grammar. Ex: COBOL
+        # https://github.com/yutaro-sakamoto/tree-sitter-cobol/blob/main/grammar.js#L13
+        try:
+            language_function = getattr(lib, f"tree_sitter_{self.language}")
+        except AttributeError:
+            language = self.language.upper()
+            language_function = getattr(lib, f"tree_sitter_{language}")
 
         language_function.restype = c_void_p
         pointer = language_function()
         return pointer
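
Note: the fix probes the shared object for the grammar's entry point twice, since some grammars export their symbol in uppercase (the diff's comment cites tree-sitter-cobol, whose grammar name is COBOL). The fallback reduces to this sketch, where lib is a ctypes CDLL handle as in the real code:

    def resolve_language_function(lib, language: str):
        try:
            return getattr(lib, f"tree_sitter_{language}")
        except AttributeError:
            # e.g. tree-sitter-cobol exports tree_sitter_COBOL
            return getattr(lib, f"tree_sitter_{language.upper()}")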
janus/llm/models_info.py
CHANGED
@@ -6,9 +6,13 @@ from typing import Callable, Protocol, TypeVar
 from dotenv import load_dotenv
 from langchain_community.llms import HuggingFaceTextGenInference
 from langchain_core.runnables import Runnable
-from langchain_openai import AzureChatOpenAI
+from langchain_openai import AzureChatOpenAI, ChatOpenAI
 
-from janus.llm.model_callbacks import
+from janus.llm.model_callbacks import (
+    COST_PER_1K_TOKENS,
+    azure_model_reroutes,
+    openai_model_reroutes,
+)
 from janus.prompts.prompt import (
     ChatGptPromptEngine,
     ClaudePromptEngine,
@@ -127,7 +131,7 @@ bedrock_models = [
 all_models = [*azure_models, *bedrock_models]
 
 MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
-
+    "OpenAI": ChatOpenAI,
     "HuggingFace": HuggingFaceTextGenInference,
     "Azure": AzureChatOpenAI,
     "Bedrock": Bedrock,
@@ -137,7 +141,7 @@ MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
 
 
 MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
-
+    **{m: ChatGptPromptEngine for m in openai_models},
     **{m: ChatGptPromptEngine for m in azure_models},
     **{m: ClaudePromptEngine for m in claude_models},
     **{m: Llama2PromptEngine for m in llama2_models},
@@ -148,7 +152,7 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
 }
 
 MODEL_ID_TO_LONG_ID = {
-
+    **{m: mr for m, mr in openai_model_reroutes.items()},
     **{m: mr for m, mr in azure_model_reroutes.items()},
     "bedrock-claude-v2": "anthropic.claude-v2",
     "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
@@ -181,7 +185,7 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
 MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
 
 MODEL_TYPES: dict[str, PromptEngine] = {
-
+    **{m: "OpenAI" for m in openai_models},
     **{m: "Azure" for m in azure_models},
     **{m: "BedrockChat" for m in bedrock_models},
 }
@@ -289,15 +293,16 @@ def load_model(model_id) -> JanusModel:
         # log.warning("Waiting 10 seconds...")
         # Give enough time for the user to read the warnings and cancel
         # time.sleep(10)
-        raise DeprecationWarning("OpenAI models are no longer supported.")
+        # raise DeprecationWarning("OpenAI models are no longer supported.")
 
     elif model_type_name == "Azure":
         model_args.update(
-
-
-
-
-
+            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+            api_version=os.getenv("OPENAI_API_VERSION", "2024-02-01"),
+            azure_deployment=model_id,
+            request_timeout=3600,
+            max_tokens=4096,
         )
 
     model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
janus/parsers/eval_parsers/incose_parser.py
ADDED
@@ -0,0 +1,134 @@
+import json
+import random
+import uuid
+from typing import List
+
+from langchain.output_parsers import PydanticOutputParser
+from langchain_core.exceptions import OutputParserException
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import BaseModel, Field, validator
+
+from janus.language.block import CodeBlock
+from janus.parsers.parser import JanusParser
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+RNG = random.Random()
+
+
+class Criteria(BaseModel):
+    reasoning: str = Field(description="A short explanation for the given assessment")
+    score: str = Field("A simple `pass` or `fail`")
+
+    @validator("score")
+    def score_is_valid(cls, v: str):
+        v = v.lower().strip()
+        if v not in {"pass", "fail"}:
+            raise OutputParserException("Score must be either 'pass' or 'fail'")
+        return v
+
+
+class Requirement(BaseModel):
+    requirement_id: str = Field(description="The 8-character comment ID")
+    requirement: str = Field(description="The original requirement being evaluated")
+    C1: Criteria
+    C2: Criteria
+    C3: Criteria
+    C4: Criteria
+    C5: Criteria
+    C6: Criteria
+    C7: Criteria
+    C8: Criteria
+    C9: Criteria
+
+
+class RequirementList(BaseModel):
+    __root__: List[Requirement] = Field(
+        description=(
+            "A list of requirement evaluations. Each element should include"
+            " the requirement's 8-character ID in the `requirement_id` field,"
+            " the original requirement in the 'requirement' field, "
+            " and nine score objects corresponding to each criterion."
+        )
+    )
+
+
+class IncoseParser(JanusParser, PydanticOutputParser):
+    requirements: dict[str, str]
+
+    def __init__(self):
+        PydanticOutputParser.__init__(
+            self,
+            pydantic_object=RequirementList,
+            requirements={},
+        )
+
+    def parse_input(self, block: CodeBlock) -> str:
+        # TODO: Perform comment stripping/placeholding here rather than in script
+        text = super().parse_input(block)
+        RNG.seed(text)
+
+        obj = json.loads(text)
+
+        # For some reason requirements objects are in a double list?
+        reqs = obj["requirements"]
+
+        # Generate a unique ID for each requirement (ensure they are unique)
+        req_ids = set()
+        while len(req_ids) < len(reqs):
+            req_ids.add(str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8])
+
+        self.requirements = dict(zip(req_ids, reqs))
+        reqs_str = "\n\n".join(
+            f"Requirement {rid} : {req}" for rid, req in self.requirements.items()
+        )
+        obj["requirements"] = reqs_str
+        return json.dumps(obj)
+
+    def parse(self, text: str | BaseMessage) -> str:
+        if isinstance(text, BaseMessage):
+            text = str(text.content)
+
+        # Strip everything outside the JSON object
+        begin, end = text.find("["), text.rfind("]")
+        text = text[begin : end + 1]
+
+        try:
+            out: RequirementList = super().parse(text)
+        except json.JSONDecodeError as e:
+            log.debug(f"Invalid JSON object. Output:\n{text}")
+            raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+
+        evals: dict[str, dict] = {c.requirement_id: c.dict() for c in out.__root__}
+
+        seen_keys = set(evals.keys())
+        expected_keys = set(self.requirements.keys())
+        missing_keys = expected_keys.difference(seen_keys)
+        invalid_keys = seen_keys.difference(expected_keys)
+        if missing_keys:
+            log.debug(f"Missing keys: {missing_keys}")
+            if invalid_keys:
+                log.debug(f"Invalid keys: {invalid_keys}")
+            log.debug(f"Missing keys: {missing_keys}")
+            raise OutputParserException(
+                f"Got invalid return object. Missing the following expected "
+                f"keys: {missing_keys}"
+            )
+
+        for key in invalid_keys:
+            del evals[key]
+
+        for rid in evals.keys():
+            evals[rid]["requirement"] = self.requirements[rid]
+            evals[rid].pop("requirement_id")
+
+        return json.dumps(evals)
+
+    def parse_combined_output(self, text: str) -> str:
+        if not text.strip():
+            return str({})
+        objs = [json.loads(line.strip()) for line in text.split("\n") if line.strip()]
+        output_obj = {}
+        for obj in objs:
+            output_obj.update(obj)
+        return json.dumps(output_obj)
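
Note: both new parsers validate the model's output against a Pydantic model whose __root__ is a list, so a bare JSON array parses directly. A self-contained sketch of that pattern, using the same pydantic v1 re-export the parsers import:

    from typing import List

    from langchain_core.pydantic_v1 import BaseModel  # pydantic v1 API, as in the diff

    class Item(BaseModel):
        id: str

    class ItemList(BaseModel):
        __root__: List[Item]  # validates a bare JSON array

    items = ItemList.parse_raw('[{"id": "abcd1234"}, {"id": "ef567890"}]')
    print([i.id for i in items.__root__])  # ['abcd1234', 'ef567890']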
janus/parsers/eval_parsers/inline_comment_parser.py
ADDED
@@ -0,0 +1,112 @@
+import json
+import re
+from typing import Any
+
+from langchain.output_parsers import PydanticOutputParser
+from langchain_core.exceptions import OutputParserException
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import BaseModel, Field, conint
+
+from janus.language.block import CodeBlock
+from janus.parsers.parser import JanusParser
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class Criteria(BaseModel):
+    reasoning: str = Field(description="A short explanation for the given score")
+    # Constrained to an integer between 1 and 4
+    score: conint(ge=1, le=4) = Field(  # type: ignore
+        description="An integer score between 1 and 4 (inclusive), 4 being the best"
+    )
+
+
+class Comment(BaseModel):
+    comment_id: str = Field(description="The 8-character comment ID")
+    completeness: Criteria = Field(description="The completeness of the comment")
+    hallucination: Criteria = Field(description="The factualness of the comment")
+    readability: Criteria = Field(description="The readability of the comment")
+    usefulness: Criteria = Field(description="The usefulness of the comment")
+
+
+class CommentList(BaseModel):
+    __root__: list[Comment] = Field(
+        description=(
+            "A list of inline comment evaluations. Each element should include"
+            " the comment's 8-character ID in the `comment_id` field, and four"
+            " score objects corresponding to each metric (`completeness`,"
+            " `hallucination`, `readability`, and `usefulness`)."
+        )
+    )
+
+
+class InlineCommentParser(JanusParser, PydanticOutputParser):
+    comments: dict[str, str]
+
+    def __init__(self):
+        PydanticOutputParser.__init__(
+            self,
+            pydantic_object=CommentList,
+            comments=[],
+        )
+
+    def parse_input(self, block: CodeBlock) -> str:
+        # TODO: Perform comment stripping/placeholding here rather than in script
+        text = super().parse_input(block)
+        self.comments = dict(
+            re.findall(
+                r"<(?:BLOCK|INLINE)_COMMENT (\w{8})> (.*)$",
+                text,
+                flags=re.MULTILINE,
+            )
+        )
+        return text
+
+    def parse(self, text: str | BaseMessage) -> str:
+        if isinstance(text, BaseMessage):
+            text = str(text.content)
+
+        # Strip everything outside the JSON object
+        begin, end = text.find("["), text.rfind("]")
+        text = text[begin : end + 1]
+
+        try:
+            out: CommentList = super().parse(text)
+        except json.JSONDecodeError as e:
+            log.debug(f"Invalid JSON object. Output:\n{text}")
+            raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+
+        evals: dict[str, Any] = {c.comment_id: c.dict() for c in out.__root__}
+
+        seen_keys = set(evals.keys())
+        expected_keys = set(self.comments.keys())
+        missing_keys = expected_keys.difference(seen_keys)
+        invalid_keys = seen_keys.difference(expected_keys)
+        if missing_keys:
+            log.debug(f"Missing keys: {missing_keys}")
+            if invalid_keys:
+                log.debug(f"Invalid keys: {invalid_keys}")
+            log.debug(f"Missing keys: {missing_keys}")
+            raise OutputParserException(
+                f"Got invalid return object. Missing the following expected "
+                f"keys: {missing_keys}"
+            )
+
+        for key in invalid_keys:
+            del evals[key]
+
+        for cid in evals.keys():
+            evals[cid]["comment"] = self.comments[cid]
+            evals[cid].pop("comment_id")
+
+        return json.dumps(evals)
+
+    def parse_combined_output(self, text: str) -> str:
+        if not text.strip():
+            return str({})
+        objs = [json.loads(line.strip()) for line in text.split("\n") if line.strip()]
+        output_obj = {}
+        for obj in objs:
+            output_obj.update(obj)
+        return json.dumps(output_obj)
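
Note: InlineCommentParser keys its evaluations off the placeholder markers already present in the input; parse_input collects each 8-character ID and its comment text with a single multiline regex. For example:

    import re

    text = 'x = 1  <INLINE_COMMENT abcd1234> set x\n<BLOCK_COMMENT ef567890> header\n'
    comments = dict(
        re.findall(r"<(?:BLOCK|INLINE)_COMMENT (\w{8})> (.*)$", text, flags=re.MULTILINE)
    )
    print(comments)  # {'abcd1234': 'set x', 'ef567890': 'header'}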
janus/parsers/partition_parser.py
CHANGED
@@ -36,6 +36,29 @@ class PartitionList(BaseModel):
     )
 
 
+# The following IDs appear in the prompt example. If the LLM produces them,
+# they should be ignored
+EXAMPLE_IDS = {
+    "0d2f4f8d",
+    "def2a953",
+    "75315253",
+    "e7f928da",
+    "1781b2a9",
+    "2fe21e27",
+    "9aef6179",
+    "6061bd82",
+    "22bd0c30",
+    "5d85e19e",
+    "06027969",
+    "91b722fb",
+    "4b3f79be",
+    "k57w964a",
+    "51638s96",
+    "065o6q32",
+    "j5q6p852",
+}
+
+
 class PartitionParser(JanusParser, PydanticOutputParser):
     token_limit: int
     model: BaseLanguageModel
@@ -59,7 +82,10 @@ class PartitionParser(JanusParser, PydanticOutputParser):
         # Generate a unique ID for each line (ensure they are unique)
         line_ids = set()
         while len(line_ids) < len(self.lines):
-            line_ids.add(str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8])
+            line_id = str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8]
+            if line_id in EXAMPLE_IDS:
+                continue
+            line_ids.add(line_id)
 
         # Prepend each line with the corresponding ID, save the mapping
         self.line_id_to_index = {lid: i for i, lid in enumerate(line_ids)}
@@ -72,18 +98,24 @@ class PartitionParser(JanusParser, PydanticOutputParser):
         if isinstance(text, BaseMessage):
             text = str(text.content)
 
+        # Strip everything outside the JSON object
+        begin, end = text.find("["), text.rfind("]")
+        text = text[begin : end + 1]
+
         try:
             out: PartitionList = super().parse(text)
         except (OutputParserException, json.JSONDecodeError):
             log.debug(f"Invalid JSON object. Output:\n{text}")
             raise
 
+        # Get partition locations, discard reasoning
+        partition_locations = {partition.location for partition in out.__root__}
+
+        # Ignore IDs from the example input
+        partition_locations.difference_update(EXAMPLE_IDS)
+
         # Locate any invalid line IDs, raise exception if any found
-        invalid_splits = [
-            partition.location
-            for partition in out.__root__
-            if partition.location not in self.line_id_to_index
-        ]
+        invalid_splits = partition_locations.difference(self.line_id_to_index)
         if invalid_splits:
             err_msg = (
                 f"{len(invalid_splits)} line ID(s) not found in input: "
@@ -95,9 +127,9 @@ class PartitionParser(JanusParser, PydanticOutputParser):
         # Map line IDs to indices (so they can be sorted and lines indexed)
         index_to_line_id = {0: "START", None: "END"}
         split_points = {0}
-        for partition in out.__root__:
-            index = self.line_id_to_index[partition.location]
-            index_to_line_id[index] = partition.location
+        for partition in partition_locations:
+            index = self.line_id_to_index[partition]
+            index_to_line_id[index] = partition
             split_points.add(index)
 
         # Get partition start/ends, chunks, chunk lengths
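
Note: the parser change hardens line-ID generation: IDs are 8-character UUID prefixes drawn from the module's RNG, and any candidate colliding with the hard-coded EXAMPLE_IDS from the prompt example is skipped. A standalone sketch of collision-free short-ID generation (the deterministic seeding here mirrors IncoseParser.parse_input; partition_parser's own seeding is not shown in this hunk):

    import random
    import uuid

    def generate_ids(n: int, seed_text: str, reserved: set[str]) -> set[str]:
        rng = random.Random()
        rng.seed(seed_text)  # deterministic for a given input
        ids: set[str] = set()
        while len(ids) < n:
            candidate = str(uuid.UUID(int=rng.getrandbits(128), version=4))[:8]
            if candidate in reserved:
                continue  # never emit an ID reserved by the prompt example
            ids.add(candidate)
        return ids

    print(sorted(generate_ids(3, "source text", {"0d2f4f8d"})))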
janus/refiners/refiner.py
CHANGED
@@ -2,6 +2,7 @@ import re
 from typing import Any
 
 from langchain.output_parsers import RetryWithErrorOutputParser
+from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompt_values import PromptValue
 from langchain_core.runnables import RunnableSerializable
@@ -26,6 +27,35 @@ class JanusRefiner(JanusParser):
         raise NotImplementedError
 
 
+class SimpleRetry(JanusRefiner):
+    max_retries: int
+    retry_chain: RunnableSerializable
+
+    def __init__(
+        self,
+        llm: JanusModel,
+        parser: JanusParser,
+        max_retries: int,
+    ):
+        retry_chain = llm | StrOutputParser()
+        super().__init__(
+            retry_chain=retry_chain,
+            parser=parser,
+            max_retries=max_retries,
+        )
+
+    def parse_completion(
+        self, completion: str, prompt_value: PromptValue, **kwargs
+    ) -> Any:
+        for retry_number in range(self.max_retries):
+            try:
+                return self.parser.parse(completion)
+            except OutputParserException:
+                completion = self.retry_chain.invoke(prompt_value)
+
+        return self.parser.parse(completion)
+
+
 class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
     def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
         retry_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
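
Note: unlike FixParserExceptions, the new SimpleRetry refiner feeds no error details back to the model; on a parse failure it simply re-invokes the same prompt, up to max_retries times. The control flow reduces to this sketch, where parse() and reask() stand in for self.parser.parse and self.retry_chain.invoke:

    def parse_with_simple_retry(completion, prompt_value, parse, reask, max_retries):
        for _ in range(max_retries):
            try:
                return parse(completion)
            except Exception:  # OutputParserException in the real code
                completion = reask(prompt_value)
        # Final attempt; the exception propagates if parsing still fails
        return parse(completion)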
janus/utils/enums.py
CHANGED
@@ -89,6 +89,20 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
         "url": "https://github.com/stsewd/tree-sitter-comment",
         "example": "# This is a comment\n",
     },
+    "cobol": {
+        "comment": "*",
+        "suffix": "cbl",
+        "url": "https://github.com/yutaro-sakamoto/tree-sitter-cobol",
+        "example": (
+            "       IDENTIFICATION DIVISION.\n"
+            "       PROGRAM-ID. HelloWorld.\n"
+            "       ENVIRONMENT DIVISION.\n"
+            "       DATA DIVISION.\n"
+            "       PROCEDURE DIVISION.\n"
+            '           DISPLAY "Hello, World!".\n'
+            "           STOP RUN.\n"
+        ),
+    },
     "commonlisp": {
         "comment": ";;",
         "suffix": "lisp",
{janus_llm-4.2.0.dist-info → janus_llm-4.3.1.dist-info}/RECORD
CHANGED
@@ -1,17 +1,17 @@
-janus/__init__.py,sha256=
+janus/__init__.py,sha256=hbiNcSyVowLc5sEqV1GU1B22molrn1w3rOxtKlgrl2E,361
 janus/__main__.py,sha256=lEkpNtLVPtFo8ySDZeXJ_NXDHb0GVdZFPWB4gD4RPS8,64
 janus/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 janus/_tests/conftest.py,sha256=V7uW-oq3YbFiRPvrq15YoVVrA1n_83pjgiyTZ-IUGW8,963
 janus/_tests/test_cli.py,sha256=6ef7h11bg4i7Q6L1-r0ZdcY7YrH4n472kvDiA03T4c8,4275
-janus/cli.py,sha256=
+janus/cli.py,sha256=zo8EEp0Y33jPCzMUGGRXxjr629ZPMIrVGk3FxinpyDQ,46851
 janus/converter/__init__.py,sha256=Jnp3TsJ4M1LWDAzXFSyxzMpygbYOxkR-qYxU-G6Gi1k,395
 janus/converter/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 janus/converter/_tests/test_translate.py,sha256=T5CzNrwHqJWfb39Izq84R9WvM3toSlJq31SeA_U7d_4,5641
 janus/converter/aggregator.py,sha256=MuAXMKmq6PuUo_w6ljyiuDn81Gk2dN-Ci7FVeLc6vhs,1966
-janus/converter/converter.py,sha256=
+janus/converter/converter.py,sha256=citSpcCsI1bDfckK38smGNafDHsc8DC9quSoXD2J-Kc,26253
 janus/converter/diagram.py,sha256=-wktVBPrSBgNIQfHIfa2bJNg6L9CYJQgrr9-xU8DFPw,1646
 janus/converter/document.py,sha256=qNt2UncMheUBadXCFHGq74tqCrvZub5DCgZpd3Qa54o,4564
-janus/converter/evaluate.py,sha256=
+janus/converter/evaluate.py,sha256=Bdue1ESQfMVFFRK4l0CvqwLyzt5bqOKy1LB9a8Hqub0,9150
 janus/converter/partition.py,sha256=ASvv4hAue44qHobO4kqr_tKr-eJsXCPPdD3NtNd9V-E,993
 janus/converter/requirements.py,sha256=9tvQ40FZJtG8niIFn45gPQCgKKHVPPoFLinBv6RAqO4,2027
 janus/converter/translate.py,sha256=S1DPZdmX9Vrn_sJPcobvXmhmS8U53yl5cRXjsmXPtas,4246
@@ -30,8 +30,8 @@ janus/language/_tests/test_combine.py,sha256=sjVVPUg4LYkAmazXGUw_S1xPrzWm67_0tCx
 janus/language/_tests/test_splitter.py,sha256=Hqexa39LLEXlK3ZUw7Zot4PUIACvye2vkq0Jaox0T10,373
 janus/language/alc/__init__.py,sha256=j7vOMGhT1Vri6p8dsjSaY-fkO5uFn0sJ0nrNGGvcizM,42
 janus/language/alc/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-janus/language/alc/_tests/test_alc.py,sha256=
-janus/language/alc/alc.py,sha256=
+janus/language/alc/_tests/test_alc.py,sha256=8LKidOPJDlMonRBX9w8AVOKHhyR-O2srW4ntzw5rEEs,1018
+janus/language/alc/alc.py,sha256=YteDO6DR5hnQULjI3j8Je-w05MH50ZARtXB66FqkZi4,7088
 janus/language/binary/__init__.py,sha256=AlNAe12ZA366kcGSrQ1FJyOdbwxFqGBFkYR2K6yL818,51
 janus/language/binary/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 janus/language/binary/_tests/test_binary.py,sha256=cIKIxjj6kIY3rcxLwqUPESP9bxWrHqMHx9TNuICgfeQ,1724
@@ -52,14 +52,14 @@ janus/language/naive/registry.py,sha256=8YQX1q0IdAm7t69-oC_00I-vfkdRnHuX-OD3KEjE
 janus/language/naive/simple_ast.py,sha256=YzeUJomVsnttJc8tI9eDROb2Hx9Vm9XKmOnLEp3TkzI,3112
 janus/language/naive/tag_splitter.py,sha256=IXWMn9tBVUGAtzvQi89GhoZ6g7fPXk5MzO0kMCr2mb0,2045
 janus/language/node.py,sha256=baoYFtapwBQqBtUN6EvHFYRkbR-EcEw1b3fQvH9zIAM,204
-janus/language/splitter.py,sha256=
+janus/language/splitter.py,sha256=ZpNIzv0ijbcH7EMnY8DIxAf0ji7-ym1iYJXS9ei_F78,17389
 janus/language/treesitter/__init__.py,sha256=mUliw7ZJLZ8NkJKyUQMSoUV82hYXE0HvLHrEdGPJF4Q,43
 janus/language/treesitter/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 janus/language/treesitter/_tests/test_treesitter.py,sha256=fmr_mFSja7vaCVu0TVyLDua3A94jMjY4AqSC5NqnOdQ,2179
-janus/language/treesitter/treesitter.py,sha256=
+janus/language/treesitter/treesitter.py,sha256=FdsBO8CEo6l9D77aHXns5jRSoZzkvrRGZFCW3oNw15c,7928
 janus/llm/__init__.py,sha256=TKLYvnsWKWfxMucy-lCLQ-4bkN9ENotJZDywDEQmrKg,45
 janus/llm/model_callbacks.py,sha256=cHRZBpYgAwiYbA2k0GQ7DBwBFQZJpEGMUBV3Q_5GTpU,7940
-janus/llm/models_info.py,sha256=
+janus/llm/models_info.py,sha256=tHH5Hf7zWBpD5zSuhxx_Tp1fQMPTKPr9EuevacDiUTU,10711
 janus/metrics/__init__.py,sha256=AsxtZJUzZiXJPr2ehPPltuYP-ddechjg6X85WZUO7mA,241
 janus/metrics/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 janus/metrics/_tests/reference.py,sha256=hiaJPP9CXkvFBV_wL-gOe_BzELTw0nvB6uCxhxtIiE8,13
@@ -89,25 +89,27 @@ janus/parsers/_tests/test_code_parser.py,sha256=3ay5QpUPcynX_EJ-YLl3PR28poutUkT7
 janus/parsers/code_parser.py,sha256=3l0HfzgrvJuiwk779s9ZsgUl3xbp1nE1qZxh8aDYRBI,873
 janus/parsers/doc_parser.py,sha256=0pUsNZ9hKQLjIi8L8BgkOBHQZ_EGoFLHrBQ4hoDkjSw,5862
 janus/parsers/eval_parser.py,sha256=Gjh6aTZgpYd2ASJUEPMo4LpCL00cBmbOqc4KM3hy8x8,2922
+janus/parsers/eval_parsers/incose_parser.py,sha256=udyK-24ocfrB1SzmggcERm73dBynrCj4MFSBV8k7YDM,4478
+janus/parsers/eval_parsers/inline_comment_parser.py,sha256=QzKgzeWPhyIEkLxJBpeutSocSJjjXEcWRRS635bXEO8,3973
 janus/parsers/parser.py,sha256=y6VV64bgVidf-oEFla3I--_28tnJsPBc6QUD_SkbfSE,1614
-janus/parsers/partition_parser.py,sha256=
+janus/parsers/partition_parser.py,sha256=IW5_aNYL4g-PzB_qJ0g0NlwLiaAGGewR5iUYF19PVL4,5738
 janus/parsers/reqs_parser.py,sha256=uRQC41Iqp22GjIvakb5UKv70UWHkcOTbOVl_RDnipYw,2438
 janus/parsers/uml.py,sha256=SwaoG9QrHKQP8rSxlf3qu_rp7OMQqYSmLgDYBapOa9M,3379
 janus/prompts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 janus/prompts/prompt.py,sha256=3796YXIzzIec9b0iUzd8VZlq-AdQbzq8qUGXLy4KH-0,10586
-janus/refiners/refiner.py,sha256=
+janus/refiners/refiner.py,sha256=ZHP0hUIv8eLpHJSd2SP1Sex6q6SdJgH7HIPgXPBw_gI,4672
 janus/refiners/uml.py,sha256=ZFvFLxOdbolYuOmZh_8K6kiHCWKuudqP71sr_TammxM,866
 janus/retrievers/retriever.py,sha256=n6MzoNZs0GJCH4eqQPS3gFlVHZ3eETr7FuHYbyPzTuo,3506
 janus/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 janus/utils/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 janus/utils/_tests/test_logger.py,sha256=jkkvrCTKwsFCsZtmyuvc-WJ0rC7LJi2Z91sIe4IiKzA,2209
 janus/utils/_tests/test_progress.py,sha256=Rs_u5PiGjP-L-o6C1fhwfE1ig8jYu9Xo9s4p8yPysl8,491
-janus/utils/enums.py,sha256=
+janus/utils/enums.py,sha256=gmvX3MYnHAwu4ZypidENIZ27M5NI_YegY3PpCDJS34Q,28094
 janus/utils/logger.py,sha256=KZeuaMAnlSZCsj4yL0P6N-JzZwpxXygzACWfdZFeuek,2337
 janus/utils/pdf_docs_reader.py,sha256=beMKHdYrFwg0m_i7n0OTJrut3sf4rEWFd7P_80A76WY,5140
 janus/utils/progress.py,sha256=PIpcQec7SrhsfqB25LHj2CDDkfm9umZx90d9LZnAx6k,1469
-janus_llm-4.
-janus_llm-4.
-janus_llm-4.
-janus_llm-4.
-janus_llm-4.
+janus_llm-4.3.1.dist-info/LICENSE,sha256=_j0st0a-HB6MRbP3_BW3PUqpS16v54luyy-1zVyl8NU,10789
+janus_llm-4.3.1.dist-info/METADATA,sha256=ZeUGDDKbJjHSk2Wkzf-4zXLIwaYZqua-5_HVFbzV2yg,4574
+janus_llm-4.3.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+janus_llm-4.3.1.dist-info/entry_points.txt,sha256=OGhQwzj6pvXp79B0SaBD5apGekCu7Dwe9fZZT_TZ544,39
+janus_llm-4.3.1.dist-info/RECORD,,
File without changes
File without changes
File without changes