janus-llm 3.3.2__py3-none-any.whl → 3.4.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- janus/__init__.py +1 -1
- janus/cli.py +51 -0
- janus/converter/converter.py +63 -23
- janus/converter/requirements.py +5 -0
- janus/language/alc/alc.py +99 -1
- janus/language/block.py +2 -0
- janus/language/naive/simple_ast.py +67 -3
- janus/parsers/refiner_parser.py +3 -1
- janus/refiners/refiner.py +17 -7
- {janus_llm-3.3.2.dist-info → janus_llm-3.4.1.dist-info}/METADATA +1 -1
- {janus_llm-3.3.2.dist-info → janus_llm-3.4.1.dist-info}/RECORD +14 -14
- {janus_llm-3.3.2.dist-info → janus_llm-3.4.1.dist-info}/LICENSE +0 -0
- {janus_llm-3.3.2.dist-info → janus_llm-3.4.1.dist-info}/WHEEL +0 -0
- {janus_llm-3.3.2.dist-info → janus_llm-3.4.1.dist-info}/entry_points.txt +0 -0
janus/__init__.py
CHANGED
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
|
|
5
5
|
from janus.converter.translate import Translator
|
6
6
|
from janus.metrics import * # noqa: F403
|
7
7
|
|
8
|
-
__version__ = "3.
|
8
|
+
__version__ = "3.4.1"
|
9
9
|
|
10
10
|
# Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
|
11
11
|
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
|
janus/cli.py
CHANGED
@@ -200,6 +200,14 @@ def translate(
|
|
200
200
|
help="Whether to overwrite existing files in the output directory",
|
201
201
|
),
|
202
202
|
] = False,
|
203
|
+
skip_context: Annotated[
|
204
|
+
bool,
|
205
|
+
typer.Option(
|
206
|
+
"--skip-context",
|
207
|
+
help="Prompts will include any context information associated with source"
|
208
|
+
" code blocks, unless this option is specified",
|
209
|
+
),
|
210
|
+
] = False,
|
203
211
|
temp: Annotated[
|
204
212
|
float,
|
205
213
|
typer.Option("--temperature", "-T", help="Sampling temperature.", min=0, max=2),
|
@@ -240,6 +248,13 @@ def translate(
|
|
240
248
|
"If unspecificed, model's default max will be used.",
|
241
249
|
),
|
242
250
|
] = None,
|
251
|
+
skip_refiner: Annotated[
|
252
|
+
bool,
|
253
|
+
typer.Option(
|
254
|
+
"--skip-refiner",
|
255
|
+
help="Whether to skip the refiner for generating output",
|
256
|
+
),
|
257
|
+
] = True,
|
243
258
|
):
|
244
259
|
try:
|
245
260
|
target_language, target_version = target_lang.split("-")
|
@@ -265,6 +280,8 @@ def translate(
|
|
265
280
|
db_path=db_loc,
|
266
281
|
db_config=collections_config,
|
267
282
|
splitter_type=splitter_type,
|
283
|
+
skip_context=skip_context,
|
284
|
+
skip_refiner=skip_refiner,
|
268
285
|
)
|
269
286
|
translator.translate(input_dir, output_dir, overwrite, collection)
|
270
287
|
|
@@ -322,6 +339,14 @@ def document(
|
|
322
339
|
help="Whether to overwrite existing files in the output directory",
|
323
340
|
),
|
324
341
|
] = False,
|
342
|
+
skip_context: Annotated[
|
343
|
+
bool,
|
344
|
+
typer.Option(
|
345
|
+
"--skip-context",
|
346
|
+
help="Prompts will include any context information associated with source"
|
347
|
+
" code blocks, unless this option is specified",
|
348
|
+
),
|
349
|
+
] = False,
|
325
350
|
doc_mode: Annotated[
|
326
351
|
str,
|
327
352
|
typer.Option(
|
@@ -378,6 +403,13 @@ def document(
|
|
378
403
|
"If unspecificed, model's default max will be used.",
|
379
404
|
),
|
380
405
|
] = None,
|
406
|
+
skip_refiner: Annotated[
|
407
|
+
bool,
|
408
|
+
typer.Option(
|
409
|
+
"--skip-refiner",
|
410
|
+
help="Whether to skip the refiner for generating output",
|
411
|
+
),
|
412
|
+
] = True,
|
381
413
|
):
|
382
414
|
model_arguments = dict(temperature=temperature)
|
383
415
|
collections_config = get_collections_config()
|
@@ -390,6 +422,8 @@ def document(
|
|
390
422
|
db_path=db_loc,
|
391
423
|
db_config=collections_config,
|
392
424
|
splitter_type=splitter_type,
|
425
|
+
skip_refiner=skip_refiner,
|
426
|
+
skip_context=skip_context,
|
393
427
|
)
|
394
428
|
if doc_mode == "madlibs":
|
395
429
|
documenter = MadLibsDocumenter(
|
@@ -458,6 +492,14 @@ def diagram(
|
|
458
492
|
help="Whether to overwrite existing files in the output directory",
|
459
493
|
),
|
460
494
|
] = False,
|
495
|
+
skip_context: Annotated[
|
496
|
+
bool,
|
497
|
+
typer.Option(
|
498
|
+
"--skip-context",
|
499
|
+
help="Prompts will include any context information associated with source"
|
500
|
+
" code blocks, unless this option is specified",
|
501
|
+
),
|
502
|
+
] = False,
|
461
503
|
temperature: Annotated[
|
462
504
|
float,
|
463
505
|
typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
|
@@ -494,6 +536,13 @@ def diagram(
|
|
494
536
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
495
537
|
),
|
496
538
|
] = "file",
|
539
|
+
skip_refiner: Annotated[
|
540
|
+
bool,
|
541
|
+
typer.Option(
|
542
|
+
"--skip-refiner",
|
543
|
+
help="Whether to skip the refiner for generating output",
|
544
|
+
),
|
545
|
+
] = True,
|
497
546
|
):
|
498
547
|
model_arguments = dict(temperature=temperature)
|
499
548
|
collections_config = get_collections_config()
|
@@ -507,6 +556,8 @@ def diagram(
|
|
507
556
|
diagram_type=diagram_type,
|
508
557
|
add_documentation=add_documentation,
|
509
558
|
splitter_type=splitter_type,
|
559
|
+
skip_refiner=skip_refiner,
|
560
|
+
skip_context=skip_context,
|
510
561
|
)
|
511
562
|
diagram_generator.translate(input_dir, output_dir, overwrite, collection)
|
512
563
|
|
janus/converter/converter.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1
1
|
import functools
|
2
2
|
import json
|
3
|
-
import math
|
4
3
|
import time
|
5
4
|
from pathlib import Path
|
6
|
-
from typing import Any
|
5
|
+
from typing import Any, List, Optional, Tuple
|
7
6
|
|
8
7
|
from langchain.output_parsers import RetryWithErrorOutputParser
|
9
8
|
from langchain_core.exceptions import OutputParserException
|
10
9
|
from langchain_core.language_models import BaseLanguageModel
|
11
10
|
from langchain_core.output_parsers import BaseOutputParser
|
12
|
-
from langchain_core.prompts import ChatPromptTemplate
|
11
|
+
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
|
13
12
|
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
14
13
|
from openai import BadRequestError, RateLimitError
|
15
14
|
from pydantic import ValidationError
|
@@ -77,6 +76,8 @@ class Converter:
|
|
77
76
|
prune_node_types: tuple[str, ...] = (),
|
78
77
|
splitter_type: str = "file",
|
79
78
|
refiner_type: str = "basic",
|
79
|
+
skip_refiner: bool = True,
|
80
|
+
skip_context: bool = False,
|
80
81
|
) -> None:
|
81
82
|
"""Initialize a Converter instance.
|
82
83
|
|
@@ -97,6 +98,8 @@ class Converter:
|
|
97
98
|
splitter_type: The type of splitter to use. Valid values are `"file"`,
|
98
99
|
`"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
|
99
100
|
refiner_type: The type of refiner to use. Valid values are `"basic"`.
|
101
|
+
skip_refiner: Whether to skip the refiner.
|
102
|
+
skip_context: Whether to skip adding context to the prompt.
|
100
103
|
"""
|
101
104
|
self._changed_attrs: set = set()
|
102
105
|
|
@@ -132,6 +135,8 @@ class Converter:
|
|
132
135
|
self._refiner_type: str
|
133
136
|
self._refiner: Refiner
|
134
137
|
|
138
|
+
self.skip_refiner = skip_refiner
|
139
|
+
|
135
140
|
self.set_splitter(splitter_type=splitter_type)
|
136
141
|
self.set_refiner(refiner_type=refiner_type)
|
137
142
|
self.set_model(model_name=model, **model_arguments)
|
@@ -142,6 +147,8 @@ class Converter:
|
|
142
147
|
self.set_db_path(db_path=db_path)
|
143
148
|
self.set_db_config(db_config=db_config)
|
144
149
|
|
150
|
+
self.skip_context = skip_context
|
151
|
+
|
145
152
|
# Child class must call this. Should we enforce somehow?
|
146
153
|
# self._load_parameters()
|
147
154
|
|
@@ -290,7 +297,7 @@ class Converter:
|
|
290
297
|
"""
|
291
298
|
if self._refiner_type == "basic":
|
292
299
|
self._refiner = BasicRefiner(
|
293
|
-
"basic_refinement", self.
|
300
|
+
"basic_refinement", self._model_id, self._source_language
|
294
301
|
)
|
295
302
|
else:
|
296
303
|
raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
|
@@ -595,37 +602,41 @@ class Converter:
|
|
595
602
|
self._parser.set_reference(block.original)
|
596
603
|
|
597
604
|
# Retries with just the output and the error
|
598
|
-
n1 = round(self.max_prompts ** (1 /
|
605
|
+
n1 = round(self.max_prompts ** (1 / 2))
|
599
606
|
|
600
607
|
# Retries with the input, output, and error
|
601
|
-
n2 = round(
|
608
|
+
n2 = round(self.max_prompts // n1)
|
602
609
|
|
603
610
|
# Retries with just the input
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
611
|
+
if not self.skip_context:
|
612
|
+
self._make_prompt_additions(block)
|
613
|
+
if not self.skip_refiner: # Make replacements in the prompt
|
614
|
+
refine_output = RefinerParser(
|
615
|
+
parser=self._parser,
|
616
|
+
initial_prompt=self._prompt.format(
|
617
|
+
**{"SOURCE_CODE": block.original.text}
|
618
|
+
),
|
619
|
+
refiner=self._refiner,
|
620
|
+
max_retries=n1,
|
621
|
+
llm=self._llm,
|
622
|
+
)
|
623
|
+
else:
|
624
|
+
refine_output = RetryWithErrorOutputParser.from_llm(
|
625
|
+
llm=self._llm,
|
626
|
+
parser=self._parser,
|
627
|
+
max_retries=n1,
|
628
|
+
)
|
618
629
|
completion_chain = self._prompt | self._llm
|
619
630
|
chain = RunnableParallel(
|
620
631
|
completion=completion_chain, prompt_value=self._prompt
|
621
|
-
) | RunnableLambda(lambda x:
|
622
|
-
for _ in range(
|
632
|
+
) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
|
633
|
+
for _ in range(n2):
|
623
634
|
try:
|
624
635
|
return chain.invoke({"SOURCE_CODE": block.original.text})
|
625
636
|
except OutputParserException:
|
626
637
|
pass
|
627
638
|
|
628
|
-
raise OutputParserException(f"Failed to parse after {n1*n2
|
639
|
+
raise OutputParserException(f"Failed to parse after {n1*n2} retries")
|
629
640
|
|
630
641
|
def _get_output_obj(
|
631
642
|
self, block: TranslatedCodeBlock
|
@@ -648,6 +659,35 @@ class Converter:
|
|
648
659
|
output=output,
|
649
660
|
)
|
650
661
|
|
662
|
+
@staticmethod
|
663
|
+
def _get_prompt_additions(block) -> Optional[List[Tuple[str, str]]]:
|
664
|
+
"""Get a list of strings to append to the prompt.
|
665
|
+
|
666
|
+
Arguments:
|
667
|
+
block: The `TranslatedCodeBlock` to save to a file.
|
668
|
+
"""
|
669
|
+
return [(key, item) for key, item in block.context_tags.items()]
|
670
|
+
|
671
|
+
def _make_prompt_additions(self, block: CodeBlock):
    """Prepend ``block``'s context tags to the system message of ``self._prompt``.

    Each ``(tag, value)`` pair from ``self._get_prompt_additions(block)`` is
    rendered as a ``"tag: value"`` line (pairs whose value is ``None`` are
    skipped). The lines are placed in front of the first
    ``SystemMessagePromptTemplate`` in ``self._prompt.messages``, which is
    replaced in place.

    Context prepended by the previous call is removed first, so calling this
    once per block does not accumulate the context of earlier blocks.

    Arguments:
        block: The code block whose context tags are added to the prompt.
    """
    # Braces are doubled so context text stays literal when the system message
    # is rebuilt with ``from_template`` (assumes the default f-string template
    # format -- TODO confirm for prompts using another template format).
    additional_context = "".join(
        f"{context_tag}: {context}\n".replace("{", "{{").replace("}", "}}")
        for context_tag, context in (self._get_prompt_additions(block) or [])
        if context is not None
    )

    # Context that an earlier call prepended to the current system template.
    previous_context = getattr(self, "_applied_prompt_context", "")

    for i, message in enumerate(self._prompt.messages):
        if not isinstance(message, SystemMessagePromptTemplate):
            continue
        base_template = message.prompt.template
        # Drop the previous block's context so it is not stacked under the new one.
        if previous_context and base_template.startswith(previous_context):
            base_template = base_template[len(previous_context) :]
        # Directly modify the message in the list
        self._prompt.messages[i] = SystemMessagePromptTemplate.from_template(
            additional_context + base_template
        )
        self._applied_prompt_context = additional_context
        break  # Assuming there's only one system message to update
|
690
|
+
|
651
691
|
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
|
652
692
|
"""Save a file to disk.
|
653
693
|
|
janus/converter/requirements.py
CHANGED
@@ -22,6 +22,11 @@ class RequirementsDocumenter(Documenter):
|
|
22
22
|
self._combiner = ChunkCombiner()
|
23
23
|
self._parser = RequirementsParser()
|
24
24
|
|
25
|
+
@staticmethod
def get_prompt_replacements(block) -> dict[str, str]:
    """Build the template substitutions used for a requirements prompt.

    Arguments:
        block: The code block being documented; ``block.original.text`` is used.

    Returns:
        A fresh mapping with a single ``"SOURCE_CODE"`` entry holding the
        block's original source text.
    """
    source_text = block.original.text
    return dict(SOURCE_CODE=source_text)
|
29
|
+
|
25
30
|
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
|
26
31
|
"""Save a file to disk.
|
27
32
|
|
janus/language/alc/alc.py
CHANGED
@@ -1,3 +1,6 @@
|
|
1
|
+
import re
|
2
|
+
from typing import Optional
|
3
|
+
|
1
4
|
from langchain.schema.language_model import BaseLanguageModel
|
2
5
|
|
3
6
|
from janus.language.block import CodeBlock
|
@@ -61,7 +64,11 @@ class AlcSplitter(TreeSitterSplitter):
|
|
61
64
|
# next csect or dsect instruction
|
62
65
|
sects: list[list[CodeBlock]] = [[]]
|
63
66
|
for c in block.children:
|
64
|
-
if c.node_type
|
67
|
+
if c.node_type == "csect_instruction":
|
68
|
+
c.context_tags["alc_section"] = "CSECT"
|
69
|
+
sects.append([c])
|
70
|
+
elif c.node_type == "dsect_instruction":
|
71
|
+
c.context_tags["alc_section"] = "DSECT"
|
65
72
|
sects.append([c])
|
66
73
|
else:
|
67
74
|
sects[-1].append(c)
|
@@ -85,3 +92,94 @@ class AlcSplitter(TreeSitterSplitter):
|
|
85
92
|
queue.extend(block.children)
|
86
93
|
|
87
94
|
return root
|
95
|
+
|
96
|
+
|
97
|
+
class AlcListingSplitter(AlcSplitter):
    """A class for splitting ALC listing code into functional blocks to
    prompt with for transcoding.

    The listing header, the left-hand listing columns, and the address column
    on the right are stripped before the remaining source text is parsed by
    ``AlcSplitter``.
    """

    def __init__(
        self,
        model: None | BaseLanguageModel = None,
        max_tokens: int = 4096,
        protected_node_types: tuple[str, ...] = (),
        prune_node_types: tuple[str, ...] = (),
        prune_unprotected: bool = False,
    ):
        """Initialize an AlcListingSplitter instance.

        Arguments:
            model: The language model, forwarded to ``AlcSplitter``.
            max_tokens: The maximum number of tokens supported by the model
            protected_node_types: Node types forwarded to ``AlcSplitter``.
            prune_node_types: Node types forwarded to ``AlcSplitter``.
            prune_unprotected: Forwarded to ``AlcSplitter``.
        """
        # The string to mark the end of the listing header
        self.header_indicator_str: str = (
            "Loc Object Code Addr1 Addr2 Stmt Source Statement"
        )
        # How many characters to trim from the right side to remove the address column
        self.address_column_chars: int = 10
        # The string to mark the end of the left margin
        self.left_margin_indicator_str: str = "Stmt"
        super().__init__(
            model=model,
            max_tokens=max_tokens,
            protected_node_types=protected_node_types,
            prune_node_types=prune_node_types,
            prune_unprotected=prune_unprotected,
        )

    def _get_ast(self, code: str) -> CodeBlock:
        """Parse a listing into a ``CodeBlock`` tree, tagging the root with the
        listing's active usings when present."""
        active_usings = self.get_active_usings(code)
        code = self.preproccess_assembly(code)
        ast: CodeBlock = super()._get_ast(code)
        # Copy before writing: ``CodeBlock`` defaults ``context_tags`` to a
        # shared ``{}`` literal, so an in-place write could tag every block.
        ast.context_tags = dict(ast.context_tags)
        if active_usings is not None:
            ast.context_tags["active_usings"] = active_usings
        return ast

    def preproccess_assembly(self, code: str) -> str:
        """Remove non-essential lines and columns from an assembly listing."""
        lines = code.splitlines()
        lines = self.strip_header_and_left(lines)
        lines = self.strip_addresses(lines)
        # Re-join with newlines so each statement keeps its own line; joining
        # with "" would merge the whole listing into a single line.
        return "\n".join(lines)

    def get_active_usings(self, code: str) -> Optional[str]:
        """Return the text after the first "Active Usings:" marker in ``code``
        (surrounding whitespace stripped), or ``None`` if there is no marker."""
        marker = "Active Usings:"
        for line in code.splitlines():
            if marker in line:
                return line.split(marker, 1)[1].strip()
        return None

    def strip_header_and_left(
        self,
        lines: list[str],
    ) -> list[str]:
        """Remove the header and the left panel from the assembly sample.

        Every line up to and including the first line containing
        ``self.header_indicator_str`` is dropped. For the remaining lines, the
        columns up to and including ``self.left_margin_indicator_str`` (as
        positioned on that header line), plus one separator, are removed.

        Raises:
            ValueError: If no line contains ``self.header_indicator_str``.
        """
        header_pattern = re.compile(re.escape(self.header_indicator_str))
        header_end_index = next(
            (i for i, line in enumerate(lines) if header_pattern.search(line)),
            None,
        )
        if header_end_index is None:
            raise ValueError(
                f"ALC listing header {self.header_indicator_str!r} not found"
            )

        left_content_end_column = lines[header_end_index].find(
            self.left_margin_indicator_str
        )
        # Skip the indicator text itself plus one separating character.
        left_margin_width = (
            left_content_end_column + len(self.left_margin_indicator_str) + 1
        )
        return [line[left_margin_width:] for line in lines[header_end_index + 1 :]]

    def strip_addresses(self, lines: list[str]) -> list[str]:
        """Strip the addresses which run down the right side of the assembly snippet"""
        width = self.address_column_chars
        if width <= 0:
            # ``line[:-0]`` is ``line[:0]`` and would erase every line.
            return list(lines)
        return [line[:-width] for line in lines]

    def strip_footer(self, lines: list[str]):
        """Strip the footer from the assembly snippet (not implemented yet).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|
janus/language/block.py
CHANGED
@@ -45,6 +45,7 @@ class CodeBlock:
|
|
45
45
|
children: list[ForwardRef("CodeBlock")],
|
46
46
|
embedding_id: Optional[str] = None,
|
47
47
|
affixes: Tuple[str, str] = ("", ""),
|
48
|
+
context_tags: dict[str, str] = {},
|
48
49
|
) -> None:
|
49
50
|
self.id: Hashable = id
|
50
51
|
self.name: Optional[str] = name
|
@@ -59,6 +60,7 @@ class CodeBlock:
|
|
59
60
|
self.children: list[ForwardRef("CodeBlock")] = sorted(children)
|
60
61
|
self.embedding_id: Optional[str] = embedding_id
|
61
62
|
self.affixes: Tuple[str, str] = affixes
|
63
|
+
self.context_tags: dict[str, str] = context_tags
|
62
64
|
|
63
65
|
self.complete = True
|
64
66
|
self.omit_prefix = True
|
@@ -1,12 +1,24 @@
|
|
1
|
-
from janus.language.alc.alc import AlcSplitter
|
1
|
+
from janus.language.alc.alc import AlcListingSplitter, AlcSplitter
|
2
2
|
from janus.language.mumps.mumps import MumpsSplitter
|
3
3
|
from janus.language.naive.registry import register_splitter
|
4
|
+
from janus.language.splitter import Splitter
|
4
5
|
from janus.language.treesitter import TreeSitterSplitter
|
5
6
|
from janus.utils.enums import LANGUAGES
|
7
|
+
from janus.utils.logger import create_logger
|
8
|
+
|
9
|
+
log = create_logger(__name__)
|
6
10
|
|
7
11
|
|
8
12
|
@register_splitter("ast-flex")
|
9
|
-
def get_flexible_ast(language: str, **kwargs):
|
13
|
+
def get_flexible_ast(language: str, **kwargs) -> Splitter:
|
14
|
+
"""Get a flexible AST splitter for the given language.
|
15
|
+
|
16
|
+
Arguments:
|
17
|
+
language: The language to get the splitter for.
|
18
|
+
|
19
|
+
Returns:
|
20
|
+
A flexible AST splitter for the given language.
|
21
|
+
"""
|
10
22
|
if language == "ibmhlasm":
|
11
23
|
return AlcSplitter(**kwargs)
|
12
24
|
elif language == "mumps":
|
@@ -16,7 +28,17 @@ def get_flexible_ast(language: str, **kwargs):
|
|
16
28
|
|
17
29
|
|
18
30
|
@register_splitter("ast-strict")
|
19
|
-
def get_strict_ast(language: str, **kwargs):
|
31
|
+
def get_strict_ast(language: str, **kwargs) -> Splitter:
|
32
|
+
"""Get a strict AST splitter for the given language.
|
33
|
+
|
34
|
+
The strict splitter will only return nodes that are of a functional type.
|
35
|
+
|
36
|
+
Arguments:
|
37
|
+
language: The language to get the splitter for.
|
38
|
+
|
39
|
+
Returns:
|
40
|
+
A strict AST splitter for the given language.
|
41
|
+
"""
|
20
42
|
kwargs.update(
|
21
43
|
protected_node_types=LANGUAGES[language]["functional_node_types"],
|
22
44
|
prune_unprotected=True,
|
@@ -27,3 +49,45 @@ def get_strict_ast(language: str, **kwargs):
|
|
27
49
|
return MumpsSplitter(**kwargs)
|
28
50
|
else:
|
29
51
|
return TreeSitterSplitter(language=language, **kwargs)
|
52
|
+
|
53
|
+
|
54
|
+
@register_splitter("ast-strict-listing")
def get_strict_listing_ast(language: str, **kwargs) -> Splitter:
    """Get a strict AST splitter for the given language. This splitter is intended for
    use with IBM HLASM.

    The strict splitter will only return nodes that are of a functional type.

    For languages other than ``"ibmhlasm"`` a warning is logged and the fallback
    mirrors the non-listing splitters: ``MumpsSplitter`` for ``"mumps"``,
    otherwise ``TreeSitterSplitter``.

    Arguments:
        language: The language to get the splitter for.
        **kwargs: Extra splitter arguments; ``protected_node_types`` and
            ``prune_unprotected`` are overridden.

    Returns:
        A strict AST splitter for the given language.
    """
    kwargs.update(
        protected_node_types=LANGUAGES[language]["functional_node_types"],
        prune_unprotected=True,
    )
    if language == "ibmhlasm":
        return AlcListingSplitter(**kwargs)
    log.warning("Listing splitter is only intended for use with IBMHLASM!")
    if language == "mumps":
        # The non-listing splitters also route MUMPS to its dedicated splitter.
        return MumpsSplitter(**kwargs)
    return TreeSitterSplitter(language=language, **kwargs)
|
76
|
+
|
77
|
+
|
78
|
+
@register_splitter("ast-flex-listing")
def get_flexible_listing_ast(language: str, **kwargs) -> Splitter:
    """Get a flexible AST splitter for the given language. This splitter is intended for
    use with IBM HLASM.

    For languages other than ``"ibmhlasm"`` a warning is logged and the fallback
    mirrors the non-listing splitters: ``MumpsSplitter`` for ``"mumps"``,
    otherwise ``TreeSitterSplitter``.

    Arguments:
        language: The language to get the splitter for.
        **kwargs: Extra arguments forwarded to the splitter constructor.

    Returns:
        A flexible AST splitter for the given language.
    """
    if language == "ibmhlasm":
        return AlcListingSplitter(**kwargs)
    log.warning("Listing splitter is only intended for use with IBMHLASM!")
    if language == "mumps":
        # The non-listing splitters also route MUMPS to its dedicated splitter.
        return MumpsSplitter(**kwargs)
    return TreeSitterSplitter(language=language, **kwargs)
|
janus/parsers/refiner_parser.py
CHANGED
@@ -40,7 +40,9 @@ class RefinerParser(BaseOutputParser):
|
|
40
40
|
return self.parser.parse(text)
|
41
41
|
except OutputParserException as oe:
|
42
42
|
err = str(oe)
|
43
|
-
new_prompt, prompt_arguments = self.refiner.refine(
|
43
|
+
new_prompt, prompt_arguments = self.refiner.refine(
|
44
|
+
self.initial_prompt, last_prompt, text, err
|
45
|
+
)
|
44
46
|
new_chain = new_prompt | self.llm
|
45
47
|
text = new_chain.invoke(prompt_arguments)
|
46
48
|
last_prompt = new_prompt.format(**prompt_arguments)
|
janus/refiners/refiner.py
CHANGED
@@ -5,7 +5,12 @@ from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
|
5
5
|
|
6
6
|
class Refiner:
|
7
7
|
def refine(
|
8
|
-
self,
|
8
|
+
self,
|
9
|
+
original_prompt: str,
|
10
|
+
previous_prompt: str,
|
11
|
+
previous_output: str,
|
12
|
+
errors: str,
|
13
|
+
**kwargs,
|
9
14
|
) -> tuple[ChatPromptTemplate, dict[str, str]]:
|
10
15
|
"""Creates a new prompt based on feedback from original results
|
11
16
|
|
@@ -24,22 +29,27 @@ class BasicRefiner(Refiner):
|
|
24
29
|
def __init__(
|
25
30
|
self,
|
26
31
|
prompt_name: str,
|
27
|
-
|
32
|
+
model_id: str,
|
28
33
|
source_language: str,
|
29
34
|
) -> None:
|
30
35
|
"""Basic refiner, asks llm to fix output of previous prompt given errors
|
31
36
|
|
32
37
|
Arguments:
|
33
38
|
prompt_name: refinement prompt name to use
|
34
|
-
|
39
|
+
model_id: ID of the llm to use. Found in models_info.py
|
35
40
|
source_language: source_langauge to use
|
36
41
|
"""
|
37
42
|
self._prompt_name = prompt_name
|
38
|
-
self.
|
43
|
+
self._model_id = model_id
|
39
44
|
self._source_language = source_language
|
40
45
|
|
41
46
|
def refine(
|
42
|
-
self,
|
47
|
+
self,
|
48
|
+
original_prompt: str,
|
49
|
+
previous_prompt: str,
|
50
|
+
previous_output: str,
|
51
|
+
errors: str,
|
52
|
+
**kwargs,
|
43
53
|
) -> tuple[ChatPromptTemplate, dict[str, str]]:
|
44
54
|
"""Creates a new prompt based on feedback from original results
|
45
55
|
|
@@ -51,13 +61,13 @@ class BasicRefiner(Refiner):
|
|
51
61
|
Returns:
|
52
62
|
Tuple of new prompt and prompt arguments
|
53
63
|
"""
|
54
|
-
prompt_engine = MODEL_PROMPT_ENGINES[self.
|
64
|
+
prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
|
55
65
|
prompt_template=self._prompt_name,
|
56
66
|
source_language=self._source_language,
|
57
67
|
)
|
58
68
|
prompt_arguments = {
|
59
69
|
"ORIGINAL_PROMPT": original_prompt,
|
60
|
-
"OUTPUT":
|
70
|
+
"OUTPUT": previous_output,
|
61
71
|
"ERRORS": errors,
|
62
72
|
}
|
63
73
|
return prompt_engine.prompt, prompt_arguments
|
@@ -1,17 +1,17 @@
|
|
1
|
-
janus/__init__.py,sha256=
|
1
|
+
janus/__init__.py,sha256=XGgjz8H0Qo2cAkx4nww-GdVfdiM8hrSJfv6mGsEwy1s,361
|
2
2
|
janus/__main__.py,sha256=lEkpNtLVPtFo8ySDZeXJ_NXDHb0GVdZFPWB4gD4RPS8,64
|
3
3
|
janus/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
janus/_tests/conftest.py,sha256=V7uW-oq3YbFiRPvrq15YoVVrA1n_83pjgiyTZ-IUGW8,963
|
5
5
|
janus/_tests/test_cli.py,sha256=oYJsUGWfpBJWEGRG5NGxdJedU5DU_m6fwJ7xEbJVYl0,4244
|
6
|
-
janus/cli.py,sha256=
|
6
|
+
janus/cli.py,sha256=_92FvDV4qza0nSmyiXqacYxyo1gY6IPwD4gCm6kZfqI,33213
|
7
7
|
janus/converter/__init__.py,sha256=U2EOMcCykiC0ZqhorNefOP_04hOF18qhYoPKrVp1Vrk,345
|
8
8
|
janus/converter/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
9
|
janus/converter/_tests/test_translate.py,sha256=yzcFEGc_z8QmBBBmC9dZnfL9tT8CD1rkpc8Hz44Jp4c,5631
|
10
|
-
janus/converter/converter.py,sha256=
|
10
|
+
janus/converter/converter.py,sha256=CgaNjQE6bz8qCpoYTEwuON40clOq-8r-lp2C173xS0E,27422
|
11
11
|
janus/converter/diagram.py,sha256=5mo1H3Y1uIBPYdIsWz9kxluN5DNyuUMZrtcJmGF2Uw0,5335
|
12
12
|
janus/converter/document.py,sha256=hsW512veNjFWbdl5WriuUdNmMEqZy8ktRvqn9rRmA6E,4566
|
13
13
|
janus/converter/evaluate.py,sha256=APWQUY3gjAXqkJkPzvj0UA4wPK3Cv9QSJLM-YK9t-ng,476
|
14
|
-
janus/converter/requirements.py,sha256=
|
14
|
+
janus/converter/requirements.py,sha256=9tvQ40FZJtG8niIFn45gPQCgKKHVPPoFLinBv6RAqO4,2027
|
15
15
|
janus/converter/translate.py,sha256=0brQTlSfBYmXtoM8QYIOiyr0LrTr0S1n68Du-BR7_WQ,4236
|
16
16
|
janus/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
17
|
janus/embedding/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -29,13 +29,13 @@ janus/language/_tests/test_splitter.py,sha256=Hqexa39LLEXlK3ZUw7Zot4PUIACvye2vkq
|
|
29
29
|
janus/language/alc/__init__.py,sha256=j7vOMGhT1Vri6p8dsjSaY-fkO5uFn0sJ0nrNGGvcizM,42
|
30
30
|
janus/language/alc/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
31
31
|
janus/language/alc/_tests/test_alc.py,sha256=NgVeOctm9zf-S328DdUNn9et_-lK1t5O0O2FKElb91Q,1027
|
32
|
-
janus/language/alc/alc.py,sha256=
|
32
|
+
janus/language/alc/alc.py,sha256=l1p6zwyE7ZzY9rnjsUZGuQW41hijFrzxnXnFOpfGq8k,6590
|
33
33
|
janus/language/binary/__init__.py,sha256=AlNAe12ZA366kcGSrQ1FJyOdbwxFqGBFkYR2K6yL818,51
|
34
34
|
janus/language/binary/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
35
35
|
janus/language/binary/_tests/test_binary.py,sha256=SDdI6tsQj9yXle7wBsksHuKULLMHv7mNgUkDx1nCvpw,1733
|
36
36
|
janus/language/binary/binary.py,sha256=jcc-LZx8Ss-g4j0a691U2yVJV6JGV_zpE_6y_aDq4Cw,6579
|
37
37
|
janus/language/binary/reveng/decompile_script.py,sha256=veW51oJzuO-4UD3Er062jXZ_FYtTFo9OCkl82Z2xr6A,2182
|
38
|
-
janus/language/block.py,sha256=
|
38
|
+
janus/language/block.py,sha256=2rjAYUosHFfWRgLnzf50uAgTMST4Md9Kx6JrlUfEfX4,9398
|
39
39
|
janus/language/combine.py,sha256=Wtve06fa-_Wjv_V5RIf1Nfmg0UxcOEtFNj4vVHpSNbo,2940
|
40
40
|
janus/language/file.py,sha256=jy-cReAoI6F97TXR5bbhPyt8XyUZCdFYnVboubDA_y4,571
|
41
41
|
janus/language/mumps/__init__.py,sha256=-Ou_wJ-JgHezfp1dub2_qCYNiK9wO-zo2MlqxM9qiwE,48
|
@@ -47,7 +47,7 @@ janus/language/naive/__init__.py,sha256=_Gq4inONyVYxe8WLB59d_69kqGbtF40BGKoJPnK4
|
|
47
47
|
janus/language/naive/basic_splitter.py,sha256=RM9pJK2YkHfb6_EFEV-dh_rLqkjS6v0cn3ASPf8A6Fg,459
|
48
48
|
janus/language/naive/chunk_splitter.py,sha256=ebRSbaJhDW-Hyr5__ukbdmAl6kQ1WWFqrq_SfCgHo6k,772
|
49
49
|
janus/language/naive/registry.py,sha256=8YQX1q0IdAm7t69-oC_00I-vfkdRnHuX-OD3KEjEIuU,294
|
50
|
-
janus/language/naive/simple_ast.py,sha256=
|
50
|
+
janus/language/naive/simple_ast.py,sha256=T53UwAZyRfePXzpiNUhe4FyDev6YcX1dVDxYkTcDRPE,3032
|
51
51
|
janus/language/naive/tag_splitter.py,sha256=IXWMn9tBVUGAtzvQi89GhoZ6g7fPXk5MzO0kMCr2mb0,2045
|
52
52
|
janus/language/node.py,sha256=baoYFtapwBQqBtUN6EvHFYRkbR-EcEw1b3fQvH9zIAM,204
|
53
53
|
janus/language/splitter.py,sha256=hITFp4a9bJ6sP74AWvC2GWa2Poo10MyHXYTj9hviXss,16970
|
@@ -87,12 +87,12 @@ janus/parsers/_tests/test_code_parser.py,sha256=RVgMmLvg8_57g0uJphfX-jZZsyBqOSuG
|
|
87
87
|
janus/parsers/code_parser.py,sha256=SZBsYThG4iszKlu4fHoWrs-6cbJiUFjWv4cLSr5bzDM,1790
|
88
88
|
janus/parsers/doc_parser.py,sha256=bJiOE5M7npUZur_1MWJ14C2HZl7-yXExqRXiC5ZBJvI,5679
|
89
89
|
janus/parsers/eval_parser.py,sha256=L1Lu2aNimcqUshe0FQee_9Zqj1rrqyZPXCgEAS05VJ4,2740
|
90
|
-
janus/parsers/refiner_parser.py,sha256=
|
90
|
+
janus/parsers/refiner_parser.py,sha256=5zGoPZyttfRw3kXYDKHId2nhVvAcv1QsvpDFRpe-Few,1680
|
91
91
|
janus/parsers/reqs_parser.py,sha256=6YzpF63rjuDPqpKWfYvtjpsluWQ-UboWlsKoGrGQogA,2380
|
92
92
|
janus/parsers/uml.py,sha256=ZRyGY8YxvYibacTd-WZEAAaW3XjmvJhPJE3o29f71t8,1825
|
93
93
|
janus/prompts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
94
94
|
janus/prompts/prompt.py,sha256=3796YXIzzIec9b0iUzd8VZlq-AdQbzq8qUGXLy4KH-0,10586
|
95
|
-
janus/refiners/refiner.py,sha256=
|
95
|
+
janus/refiners/refiner.py,sha256=GkV4oUSCrLAhyDJY2aY_Jt8PRF3sC6-bv58YbL2PaNk,2227
|
96
96
|
janus/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
97
97
|
janus/utils/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
98
98
|
janus/utils/_tests/test_logger.py,sha256=jkkvrCTKwsFCsZtmyuvc-WJ0rC7LJi2Z91sIe4IiKzA,2209
|
@@ -100,8 +100,8 @@ janus/utils/_tests/test_progress.py,sha256=Rs_u5PiGjP-L-o6C1fhwfE1ig8jYu9Xo9s4p8
|
|
100
100
|
janus/utils/enums.py,sha256=AoilbdiYyMvY2Mp0AM4xlbLSELfut2XMwhIM1S_msP4,27610
|
101
101
|
janus/utils/logger.py,sha256=KZeuaMAnlSZCsj4yL0P6N-JzZwpxXygzACWfdZFeuek,2337
|
102
102
|
janus/utils/progress.py,sha256=PIpcQec7SrhsfqB25LHj2CDDkfm9umZx90d9LZnAx6k,1469
|
103
|
-
janus_llm-3.
|
104
|
-
janus_llm-3.
|
105
|
-
janus_llm-3.
|
106
|
-
janus_llm-3.
|
107
|
-
janus_llm-3.
|
103
|
+
janus_llm-3.4.1.dist-info/LICENSE,sha256=_j0st0a-HB6MRbP3_BW3PUqpS16v54luyy-1zVyl8NU,10789
|
104
|
+
janus_llm-3.4.1.dist-info/METADATA,sha256=lQxVgmdepmQHUey6HX2z0wTUEg6XqpaCu91kWpbeq_E,4184
|
105
|
+
janus_llm-3.4.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
106
|
+
janus_llm-3.4.1.dist-info/entry_points.txt,sha256=OGhQwzj6pvXp79B0SaBD5apGekCu7Dwe9fZZT_TZ544,39
|
107
|
+
janus_llm-3.4.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|