janus-llm 3.3.2__py3-none-any.whl → 3.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli.py +51 -0
- janus/converter/converter.py +63 -23
- janus/converter/requirements.py +5 -0
- janus/language/alc/alc.py +99 -1
- janus/language/block.py +2 -0
- janus/language/naive/simple_ast.py +67 -3
- janus/parsers/refiner_parser.py +3 -1
- janus/refiners/refiner.py +17 -7
- {janus_llm-3.3.2.dist-info → janus_llm-3.4.1.dist-info}/METADATA +1 -1
- {janus_llm-3.3.2.dist-info → janus_llm-3.4.1.dist-info}/RECORD +14 -14
- {janus_llm-3.3.2.dist-info → janus_llm-3.4.1.dist-info}/LICENSE +0 -0
- {janus_llm-3.3.2.dist-info → janus_llm-3.4.1.dist-info}/WHEEL +0 -0
- {janus_llm-3.3.2.dist-info → janus_llm-3.4.1.dist-info}/entry_points.txt +0 -0
janus/__init__.py
CHANGED
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
|
|
5
5
|
from janus.converter.translate import Translator
|
6
6
|
from janus.metrics import * # noqa: F403
|
7
7
|
|
8
|
-
__version__ = "3.3.2"
|
8
|
+
__version__ = "3.4.1"
|
9
9
|
|
10
10
|
# Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
|
11
11
|
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
|
janus/cli.py
CHANGED
@@ -200,6 +200,14 @@ def translate(
|
|
200
200
|
help="Whether to overwrite existing files in the output directory",
|
201
201
|
),
|
202
202
|
] = False,
|
203
|
+
skip_context: Annotated[
|
204
|
+
bool,
|
205
|
+
typer.Option(
|
206
|
+
"--skip-context",
|
207
|
+
help="Prompts will include any context information associated with source"
|
208
|
+
" code blocks, unless this option is specified",
|
209
|
+
),
|
210
|
+
] = False,
|
203
211
|
temp: Annotated[
|
204
212
|
float,
|
205
213
|
typer.Option("--temperature", "-T", help="Sampling temperature.", min=0, max=2),
|
@@ -240,6 +248,13 @@ def translate(
|
|
240
248
|
"If unspecificed, model's default max will be used.",
|
241
249
|
),
|
242
250
|
] = None,
|
251
|
+
skip_refiner: Annotated[
|
252
|
+
bool,
|
253
|
+
typer.Option(
|
254
|
+
"--skip-refiner",
|
255
|
+
help="Whether to skip the refiner for generating output",
|
256
|
+
),
|
257
|
+
] = True,
|
243
258
|
):
|
244
259
|
try:
|
245
260
|
target_language, target_version = target_lang.split("-")
|
@@ -265,6 +280,8 @@ def translate(
|
|
265
280
|
db_path=db_loc,
|
266
281
|
db_config=collections_config,
|
267
282
|
splitter_type=splitter_type,
|
283
|
+
skip_context=skip_context,
|
284
|
+
skip_refiner=skip_refiner,
|
268
285
|
)
|
269
286
|
translator.translate(input_dir, output_dir, overwrite, collection)
|
270
287
|
|
@@ -322,6 +339,14 @@ def document(
|
|
322
339
|
help="Whether to overwrite existing files in the output directory",
|
323
340
|
),
|
324
341
|
] = False,
|
342
|
+
skip_context: Annotated[
|
343
|
+
bool,
|
344
|
+
typer.Option(
|
345
|
+
"--skip-context",
|
346
|
+
help="Prompts will include any context information associated with source"
|
347
|
+
" code blocks, unless this option is specified",
|
348
|
+
),
|
349
|
+
] = False,
|
325
350
|
doc_mode: Annotated[
|
326
351
|
str,
|
327
352
|
typer.Option(
|
@@ -378,6 +403,13 @@ def document(
|
|
378
403
|
"If unspecificed, model's default max will be used.",
|
379
404
|
),
|
380
405
|
] = None,
|
406
|
+
skip_refiner: Annotated[
|
407
|
+
bool,
|
408
|
+
typer.Option(
|
409
|
+
"--skip-refiner",
|
410
|
+
help="Whether to skip the refiner for generating output",
|
411
|
+
),
|
412
|
+
] = True,
|
381
413
|
):
|
382
414
|
model_arguments = dict(temperature=temperature)
|
383
415
|
collections_config = get_collections_config()
|
@@ -390,6 +422,8 @@ def document(
|
|
390
422
|
db_path=db_loc,
|
391
423
|
db_config=collections_config,
|
392
424
|
splitter_type=splitter_type,
|
425
|
+
skip_refiner=skip_refiner,
|
426
|
+
skip_context=skip_context,
|
393
427
|
)
|
394
428
|
if doc_mode == "madlibs":
|
395
429
|
documenter = MadLibsDocumenter(
|
@@ -458,6 +492,14 @@ def diagram(
|
|
458
492
|
help="Whether to overwrite existing files in the output directory",
|
459
493
|
),
|
460
494
|
] = False,
|
495
|
+
skip_context: Annotated[
|
496
|
+
bool,
|
497
|
+
typer.Option(
|
498
|
+
"--skip-context",
|
499
|
+
help="Prompts will include any context information associated with source"
|
500
|
+
" code blocks, unless this option is specified",
|
501
|
+
),
|
502
|
+
] = False,
|
461
503
|
temperature: Annotated[
|
462
504
|
float,
|
463
505
|
typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
|
@@ -494,6 +536,13 @@ def diagram(
|
|
494
536
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
495
537
|
),
|
496
538
|
] = "file",
|
539
|
+
skip_refiner: Annotated[
|
540
|
+
bool,
|
541
|
+
typer.Option(
|
542
|
+
"--skip-refiner",
|
543
|
+
help="Whether to skip the refiner for generating output",
|
544
|
+
),
|
545
|
+
] = True,
|
497
546
|
):
|
498
547
|
model_arguments = dict(temperature=temperature)
|
499
548
|
collections_config = get_collections_config()
|
@@ -507,6 +556,8 @@ def diagram(
|
|
507
556
|
diagram_type=diagram_type,
|
508
557
|
add_documentation=add_documentation,
|
509
558
|
splitter_type=splitter_type,
|
559
|
+
skip_refiner=skip_refiner,
|
560
|
+
skip_context=skip_context,
|
510
561
|
)
|
511
562
|
diagram_generator.translate(input_dir, output_dir, overwrite, collection)
|
512
563
|
|
janus/converter/converter.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1
1
|
import functools
|
2
2
|
import json
|
3
|
-
import math
|
4
3
|
import time
|
5
4
|
from pathlib import Path
|
6
|
-
from typing import Any
|
5
|
+
from typing import Any, List, Optional, Tuple
|
7
6
|
|
8
7
|
from langchain.output_parsers import RetryWithErrorOutputParser
|
9
8
|
from langchain_core.exceptions import OutputParserException
|
10
9
|
from langchain_core.language_models import BaseLanguageModel
|
11
10
|
from langchain_core.output_parsers import BaseOutputParser
|
12
|
-
from langchain_core.prompts import ChatPromptTemplate
|
11
|
+
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
|
13
12
|
from langchain_core.runnables import RunnableLambda, RunnableParallel
|
14
13
|
from openai import BadRequestError, RateLimitError
|
15
14
|
from pydantic import ValidationError
|
@@ -77,6 +76,8 @@ class Converter:
|
|
77
76
|
prune_node_types: tuple[str, ...] = (),
|
78
77
|
splitter_type: str = "file",
|
79
78
|
refiner_type: str = "basic",
|
79
|
+
skip_refiner: bool = True,
|
80
|
+
skip_context: bool = False,
|
80
81
|
) -> None:
|
81
82
|
"""Initialize a Converter instance.
|
82
83
|
|
@@ -97,6 +98,8 @@ class Converter:
|
|
97
98
|
splitter_type: The type of splitter to use. Valid values are `"file"`,
|
98
99
|
`"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
|
99
100
|
refiner_type: The type of refiner to use. Valid values are `"basic"`.
|
101
|
+
skip_refiner: Whether to skip the refiner.
|
102
|
+
skip_context: Whether to skip adding context to the prompt.
|
100
103
|
"""
|
101
104
|
self._changed_attrs: set = set()
|
102
105
|
|
@@ -132,6 +135,8 @@ class Converter:
|
|
132
135
|
self._refiner_type: str
|
133
136
|
self._refiner: Refiner
|
134
137
|
|
138
|
+
self.skip_refiner = skip_refiner
|
139
|
+
|
135
140
|
self.set_splitter(splitter_type=splitter_type)
|
136
141
|
self.set_refiner(refiner_type=refiner_type)
|
137
142
|
self.set_model(model_name=model, **model_arguments)
|
@@ -142,6 +147,8 @@ class Converter:
|
|
142
147
|
self.set_db_path(db_path=db_path)
|
143
148
|
self.set_db_config(db_config=db_config)
|
144
149
|
|
150
|
+
self.skip_context = skip_context
|
151
|
+
|
145
152
|
# Child class must call this. Should we enforce somehow?
|
146
153
|
# self._load_parameters()
|
147
154
|
|
@@ -290,7 +297,7 @@ class Converter:
|
|
290
297
|
"""
|
291
298
|
if self._refiner_type == "basic":
|
292
299
|
self._refiner = BasicRefiner(
|
293
|
-
"basic_refinement", self.
|
300
|
+
"basic_refinement", self._model_id, self._source_language
|
294
301
|
)
|
295
302
|
else:
|
296
303
|
raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
|
@@ -595,37 +602,41 @@ class Converter:
|
|
595
602
|
self._parser.set_reference(block.original)
|
596
603
|
|
597
604
|
# Retries with just the output and the error
|
598
|
-
n1 = round(self.max_prompts ** (1 /
|
605
|
+
n1 = round(self.max_prompts ** (1 / 2))
|
599
606
|
|
600
607
|
# Retries with the input, output, and error
|
601
|
-
n2 = round(
|
608
|
+
n2 = round(self.max_prompts // n1)
|
602
609
|
|
603
610
|
# Retries with just the input
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
611
|
+
if not self.skip_context:
|
612
|
+
self._make_prompt_additions(block)
|
613
|
+
if not self.skip_refiner: # Make replacements in the prompt
|
614
|
+
refine_output = RefinerParser(
|
615
|
+
parser=self._parser,
|
616
|
+
initial_prompt=self._prompt.format(
|
617
|
+
**{"SOURCE_CODE": block.original.text}
|
618
|
+
),
|
619
|
+
refiner=self._refiner,
|
620
|
+
max_retries=n1,
|
621
|
+
llm=self._llm,
|
622
|
+
)
|
623
|
+
else:
|
624
|
+
refine_output = RetryWithErrorOutputParser.from_llm(
|
625
|
+
llm=self._llm,
|
626
|
+
parser=self._parser,
|
627
|
+
max_retries=n1,
|
628
|
+
)
|
618
629
|
completion_chain = self._prompt | self._llm
|
619
630
|
chain = RunnableParallel(
|
620
631
|
completion=completion_chain, prompt_value=self._prompt
|
621
|
-
) | RunnableLambda(lambda x:
|
622
|
-
for _ in range(
|
632
|
+
) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
|
633
|
+
for _ in range(n2):
|
623
634
|
try:
|
624
635
|
return chain.invoke({"SOURCE_CODE": block.original.text})
|
625
636
|
except OutputParserException:
|
626
637
|
pass
|
627
638
|
|
628
|
-
raise OutputParserException(f"Failed to parse after {n1*n2
|
639
|
+
raise OutputParserException(f"Failed to parse after {n1*n2} retries")
|
629
640
|
|
630
641
|
def _get_output_obj(
|
631
642
|
self, block: TranslatedCodeBlock
|
@@ -648,6 +659,35 @@ class Converter:
|
|
648
659
|
output=output,
|
649
660
|
)
|
650
661
|
|
662
|
+
@staticmethod
|
663
|
+
def _get_prompt_additions(block) -> Optional[List[Tuple[str, str]]]:
|
664
|
+
"""Get a list of strings to append to the prompt.
|
665
|
+
|
666
|
+
Arguments:
|
667
|
+
block: The `TranslatedCodeBlock` to save to a file.
|
668
|
+
"""
|
669
|
+
return [(key, item) for key, item in block.context_tags.items()]
|
670
|
+
|
671
|
+
def _make_prompt_additions(self, block: CodeBlock):
|
672
|
+
# Prepare the additional context to prepend
|
673
|
+
additional_context = "".join(
|
674
|
+
[
|
675
|
+
f"{context_tag}: {context}\n"
|
676
|
+
for context_tag, context in self._get_prompt_additions(block)
|
677
|
+
]
|
678
|
+
)
|
679
|
+
|
680
|
+
# Iterate through existing messages to find and update the system message
|
681
|
+
for i, message in enumerate(self._prompt.messages):
|
682
|
+
if isinstance(message, SystemMessagePromptTemplate):
|
683
|
+
# Prepend the additional context to the system message
|
684
|
+
updated_system_message = SystemMessagePromptTemplate.from_template(
|
685
|
+
additional_context + message.prompt.template
|
686
|
+
)
|
687
|
+
# Directly modify the message in the list
|
688
|
+
self._prompt.messages[i] = updated_system_message
|
689
|
+
break # Assuming there's only one system message to update
|
690
|
+
|
651
691
|
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
|
652
692
|
"""Save a file to disk.
|
653
693
|
|
janus/converter/requirements.py
CHANGED
@@ -22,6 +22,11 @@ class RequirementsDocumenter(Documenter):
|
|
22
22
|
self._combiner = ChunkCombiner()
|
23
23
|
self._parser = RequirementsParser()
|
24
24
|
|
25
|
+
@staticmethod
|
26
|
+
def get_prompt_replacements(block) -> dict[str, str]:
|
27
|
+
prompt_replacements: dict[str, str] = {"SOURCE_CODE": block.original.text}
|
28
|
+
return prompt_replacements
|
29
|
+
|
25
30
|
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
|
26
31
|
"""Save a file to disk.
|
27
32
|
|
janus/language/alc/alc.py
CHANGED
@@ -1,3 +1,6 @@
|
|
1
|
+
import re
|
2
|
+
from typing import Optional
|
3
|
+
|
1
4
|
from langchain.schema.language_model import BaseLanguageModel
|
2
5
|
|
3
6
|
from janus.language.block import CodeBlock
|
@@ -61,7 +64,11 @@ class AlcSplitter(TreeSitterSplitter):
|
|
61
64
|
# next csect or dsect instruction
|
62
65
|
sects: list[list[CodeBlock]] = [[]]
|
63
66
|
for c in block.children:
|
64
|
-
if c.node_type
|
67
|
+
if c.node_type == "csect_instruction":
|
68
|
+
c.context_tags["alc_section"] = "CSECT"
|
69
|
+
sects.append([c])
|
70
|
+
elif c.node_type == "dsect_instruction":
|
71
|
+
c.context_tags["alc_section"] = "DSECT"
|
65
72
|
sects.append([c])
|
66
73
|
else:
|
67
74
|
sects[-1].append(c)
|
@@ -85,3 +92,94 @@ class AlcSplitter(TreeSitterSplitter):
|
|
85
92
|
queue.extend(block.children)
|
86
93
|
|
87
94
|
return root
|
95
|
+
|
96
|
+
|
97
|
+
class AlcListingSplitter(AlcSplitter):
|
98
|
+
"""A class for splitting ALC listing code into functional blocks to
|
99
|
+
prompt with for transcoding.
|
100
|
+
"""
|
101
|
+
|
102
|
+
def __init__(
|
103
|
+
self,
|
104
|
+
model: None | BaseLanguageModel = None,
|
105
|
+
max_tokens: int = 4096,
|
106
|
+
protected_node_types: tuple[str, ...] = (),
|
107
|
+
prune_node_types: tuple[str, ...] = (),
|
108
|
+
prune_unprotected: bool = False,
|
109
|
+
):
|
110
|
+
"""Initialize a AlcSplitter instance.
|
111
|
+
|
112
|
+
|
113
|
+
Arguments:
|
114
|
+
max_tokens: The maximum number of tokens supported by the model
|
115
|
+
"""
|
116
|
+
# The string to mark the end of the listing header
|
117
|
+
self.header_indicator_str: str = (
|
118
|
+
"Loc Object Code Addr1 Addr2 Stmt Source Statement"
|
119
|
+
)
|
120
|
+
# How many characters to trim from the right side to remove the address column
|
121
|
+
self.address_column_chars: int = 10
|
122
|
+
# The string to mark the end of the left margin
|
123
|
+
self.left_margin_indicator_str: str = "Stmt"
|
124
|
+
super().__init__(
|
125
|
+
model=model,
|
126
|
+
max_tokens=max_tokens,
|
127
|
+
protected_node_types=protected_node_types,
|
128
|
+
prune_node_types=prune_node_types,
|
129
|
+
prune_unprotected=prune_unprotected,
|
130
|
+
)
|
131
|
+
|
132
|
+
def _get_ast(self, code: str) -> CodeBlock:
|
133
|
+
active_usings = self.get_active_usings(code)
|
134
|
+
code = self.preproccess_assembly(code)
|
135
|
+
ast: CodeBlock = super()._get_ast(code)
|
136
|
+
ast.context_tags["active_usings"] = active_usings
|
137
|
+
return ast
|
138
|
+
|
139
|
+
def preproccess_assembly(self, code: str) -> str:
|
140
|
+
"""Remove non-essential lines from an assembly snippet"""
|
141
|
+
|
142
|
+
lines = code.splitlines()
|
143
|
+
lines = self.strip_header_and_left(lines)
|
144
|
+
lines = self.strip_addresses(lines)
|
145
|
+
return "".join(str(line) for line in lines)
|
146
|
+
|
147
|
+
def get_active_usings(self, code: str) -> Optional[str]:
|
148
|
+
"""Look for 'active usings' in the ALC listing header"""
|
149
|
+
lines = code.splitlines()
|
150
|
+
for line in lines:
|
151
|
+
if "Active Usings:" in line:
|
152
|
+
return line.split("Active Usings:")[1]
|
153
|
+
return None
|
154
|
+
|
155
|
+
def strip_header_and_left(
|
156
|
+
self,
|
157
|
+
lines: list[str],
|
158
|
+
) -> list[str]:
|
159
|
+
"""Remove the header and the left panel from the assembly sample"""
|
160
|
+
|
161
|
+
esd_regex = re.compile(f".*{self.header_indicator_str}.*")
|
162
|
+
|
163
|
+
header_end_index: int = [
|
164
|
+
i for i, item in enumerate(lines) if re.search(esd_regex, item)
|
165
|
+
][0]
|
166
|
+
|
167
|
+
left_content_end_column = lines[header_end_index].find(
|
168
|
+
self.left_margin_indicator_str
|
169
|
+
)
|
170
|
+
hori_output_lines = lines[(header_end_index + 1) :]
|
171
|
+
|
172
|
+
left_output_lines = [
|
173
|
+
line[left_content_end_column + 5 :] for line in hori_output_lines
|
174
|
+
]
|
175
|
+
return left_output_lines
|
176
|
+
|
177
|
+
def strip_addresses(self, lines: list[str]) -> list[str]:
|
178
|
+
"""Strip the addresses which run down the right side of the assembly snippet"""
|
179
|
+
|
180
|
+
stripped_lines = [line[: -self.address_column_chars] for line in lines]
|
181
|
+
return stripped_lines
|
182
|
+
|
183
|
+
def strip_footer(self, lines: list[str]):
|
184
|
+
"""Strip the footer from the assembly snippet"""
|
185
|
+
return NotImplementedError
|
janus/language/block.py
CHANGED
@@ -45,6 +45,7 @@ class CodeBlock:
|
|
45
45
|
children: list[ForwardRef("CodeBlock")],
|
46
46
|
embedding_id: Optional[str] = None,
|
47
47
|
affixes: Tuple[str, str] = ("", ""),
|
48
|
+
context_tags: dict[str, str] = {},
|
48
49
|
) -> None:
|
49
50
|
self.id: Hashable = id
|
50
51
|
self.name: Optional[str] = name
|
@@ -59,6 +60,7 @@ class CodeBlock:
|
|
59
60
|
self.children: list[ForwardRef("CodeBlock")] = sorted(children)
|
60
61
|
self.embedding_id: Optional[str] = embedding_id
|
61
62
|
self.affixes: Tuple[str, str] = affixes
|
63
|
+
self.context_tags: dict[str, str] = context_tags
|
62
64
|
|
63
65
|
self.complete = True
|
64
66
|
self.omit_prefix = True
|
janus/language/naive/simple_ast.py
CHANGED
@@ -1,12 +1,24 @@
|
|
1
|
-
from janus.language.alc.alc import AlcSplitter
|
1
|
+
from janus.language.alc.alc import AlcListingSplitter, AlcSplitter
|
2
2
|
from janus.language.mumps.mumps import MumpsSplitter
|
3
3
|
from janus.language.naive.registry import register_splitter
|
4
|
+
from janus.language.splitter import Splitter
|
4
5
|
from janus.language.treesitter import TreeSitterSplitter
|
5
6
|
from janus.utils.enums import LANGUAGES
|
7
|
+
from janus.utils.logger import create_logger
|
8
|
+
|
9
|
+
log = create_logger(__name__)
|
6
10
|
|
7
11
|
|
8
12
|
@register_splitter("ast-flex")
|
9
|
-
def get_flexible_ast(language: str, **kwargs):
|
13
|
+
def get_flexible_ast(language: str, **kwargs) -> Splitter:
|
14
|
+
"""Get a flexible AST splitter for the given language.
|
15
|
+
|
16
|
+
Arguments:
|
17
|
+
language: The language to get the splitter for.
|
18
|
+
|
19
|
+
Returns:
|
20
|
+
A flexible AST splitter for the given language.
|
21
|
+
"""
|
10
22
|
if language == "ibmhlasm":
|
11
23
|
return AlcSplitter(**kwargs)
|
12
24
|
elif language == "mumps":
|
@@ -16,7 +28,17 @@ def get_flexible_ast(language: str, **kwargs):
|
|
16
28
|
|
17
29
|
|
18
30
|
@register_splitter("ast-strict")
|
19
|
-
def get_strict_ast(language: str, **kwargs):
|
31
|
+
def get_strict_ast(language: str, **kwargs) -> Splitter:
|
32
|
+
"""Get a strict AST splitter for the given language.
|
33
|
+
|
34
|
+
The strict splitter will only return nodes that are of a functional type.
|
35
|
+
|
36
|
+
Arguments:
|
37
|
+
language: The language to get the splitter for.
|
38
|
+
|
39
|
+
Returns:
|
40
|
+
A strict AST splitter for the given language.
|
41
|
+
"""
|
20
42
|
kwargs.update(
|
21
43
|
protected_node_types=LANGUAGES[language]["functional_node_types"],
|
22
44
|
prune_unprotected=True,
|
@@ -27,3 +49,45 @@ def get_strict_ast(language: str, **kwargs):
|
|
27
49
|
return MumpsSplitter(**kwargs)
|
28
50
|
else:
|
29
51
|
return TreeSitterSplitter(language=language, **kwargs)
|
52
|
+
|
53
|
+
|
54
|
+
@register_splitter("ast-strict-listing")
|
55
|
+
def get_strict_listing_ast(language: str, **kwargs) -> Splitter:
|
56
|
+
"""Get a strict AST splitter for the given language. This splitter is intended for
|
57
|
+
use with IBM HLASM.
|
58
|
+
|
59
|
+
The strict splitter will only return nodes that are of a functional type.
|
60
|
+
|
61
|
+
Arguments:
|
62
|
+
language: The language to get the splitter for.
|
63
|
+
|
64
|
+
Returns:
|
65
|
+
A strict AST splitter for the given language.
|
66
|
+
"""
|
67
|
+
kwargs.update(
|
68
|
+
protected_node_types=LANGUAGES[language]["functional_node_types"],
|
69
|
+
prune_unprotected=True,
|
70
|
+
)
|
71
|
+
if language == "ibmhlasm":
|
72
|
+
return AlcListingSplitter(**kwargs)
|
73
|
+
else:
|
74
|
+
log.warning("Listing splitter is only intended for use with IBMHLASM!")
|
75
|
+
return TreeSitterSplitter(language=language, **kwargs)
|
76
|
+
|
77
|
+
|
78
|
+
@register_splitter("ast-flex-listing")
|
79
|
+
def get_flexible_listing_ast(language: str, **kwargs) -> Splitter:
|
80
|
+
"""Get a flexible AST splitter for the given language. This splitter is intended for
|
81
|
+
use with IBM HLASM.
|
82
|
+
|
83
|
+
Arguments:
|
84
|
+
language: The language to get the splitter for.
|
85
|
+
|
86
|
+
Returns:
|
87
|
+
A flexible AST splitter for the given language.
|
88
|
+
"""
|
89
|
+
if language == "ibmhlasm":
|
90
|
+
return AlcListingSplitter(**kwargs)
|
91
|
+
else:
|
92
|
+
log.warning("Listing splitter is only intended for use with IBMHLASM!")
|
93
|
+
return TreeSitterSplitter(language=language, **kwargs)
|
janus/parsers/refiner_parser.py
CHANGED
@@ -40,7 +40,9 @@ class RefinerParser(BaseOutputParser):
|
|
40
40
|
return self.parser.parse(text)
|
41
41
|
except OutputParserException as oe:
|
42
42
|
err = str(oe)
|
43
|
-
new_prompt, prompt_arguments = self.refiner.refine(
|
43
|
+
new_prompt, prompt_arguments = self.refiner.refine(
|
44
|
+
self.initial_prompt, last_prompt, text, err
|
45
|
+
)
|
44
46
|
new_chain = new_prompt | self.llm
|
45
47
|
text = new_chain.invoke(prompt_arguments)
|
46
48
|
last_prompt = new_prompt.format(**prompt_arguments)
|
janus/refiners/refiner.py
CHANGED
@@ -5,7 +5,12 @@ from janus.llm.models_info import MODEL_PROMPT_ENGINES
|
|
5
5
|
|
6
6
|
class Refiner:
|
7
7
|
def refine(
|
8
|
-
self,
|
8
|
+
self,
|
9
|
+
original_prompt: str,
|
10
|
+
previous_prompt: str,
|
11
|
+
previous_output: str,
|
12
|
+
errors: str,
|
13
|
+
**kwargs,
|
9
14
|
) -> tuple[ChatPromptTemplate, dict[str, str]]:
|
10
15
|
"""Creates a new prompt based on feedback from original results
|
11
16
|
|
@@ -24,22 +29,27 @@ class BasicRefiner(Refiner):
|
|
24
29
|
def __init__(
|
25
30
|
self,
|
26
31
|
prompt_name: str,
|
27
|
-
|
32
|
+
model_id: str,
|
28
33
|
source_language: str,
|
29
34
|
) -> None:
|
30
35
|
"""Basic refiner, asks llm to fix output of previous prompt given errors
|
31
36
|
|
32
37
|
Arguments:
|
33
38
|
prompt_name: refinement prompt name to use
|
34
|
-
|
39
|
+
model_id: ID of the llm to use. Found in models_info.py
|
35
40
|
source_language: source_langauge to use
|
36
41
|
"""
|
37
42
|
self._prompt_name = prompt_name
|
38
|
-
self.
|
43
|
+
self._model_id = model_id
|
39
44
|
self._source_language = source_language
|
40
45
|
|
41
46
|
def refine(
|
42
|
-
self,
|
47
|
+
self,
|
48
|
+
original_prompt: str,
|
49
|
+
previous_prompt: str,
|
50
|
+
previous_output: str,
|
51
|
+
errors: str,
|
52
|
+
**kwargs,
|
43
53
|
) -> tuple[ChatPromptTemplate, dict[str, str]]:
|
44
54
|
"""Creates a new prompt based on feedback from original results
|
45
55
|
|
@@ -51,13 +61,13 @@ class BasicRefiner(Refiner):
|
|
51
61
|
Returns:
|
52
62
|
Tuple of new prompt and prompt arguments
|
53
63
|
"""
|
54
|
-
prompt_engine = MODEL_PROMPT_ENGINES[self.
|
64
|
+
prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
|
55
65
|
prompt_template=self._prompt_name,
|
56
66
|
source_language=self._source_language,
|
57
67
|
)
|
58
68
|
prompt_arguments = {
|
59
69
|
"ORIGINAL_PROMPT": original_prompt,
|
60
|
-
"OUTPUT":
|
70
|
+
"OUTPUT": previous_output,
|
61
71
|
"ERRORS": errors,
|
62
72
|
}
|
63
73
|
return prompt_engine.prompt, prompt_arguments
|
{janus_llm-3.3.2.dist-info → janus_llm-3.4.1.dist-info}/RECORD
CHANGED
@@ -1,17 +1,17 @@
|
|
1
|
-
janus/__init__.py,sha256=
|
1
|
+
janus/__init__.py,sha256=XGgjz8H0Qo2cAkx4nww-GdVfdiM8hrSJfv6mGsEwy1s,361
|
2
2
|
janus/__main__.py,sha256=lEkpNtLVPtFo8ySDZeXJ_NXDHb0GVdZFPWB4gD4RPS8,64
|
3
3
|
janus/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
janus/_tests/conftest.py,sha256=V7uW-oq3YbFiRPvrq15YoVVrA1n_83pjgiyTZ-IUGW8,963
|
5
5
|
janus/_tests/test_cli.py,sha256=oYJsUGWfpBJWEGRG5NGxdJedU5DU_m6fwJ7xEbJVYl0,4244
|
6
|
-
janus/cli.py,sha256=
|
6
|
+
janus/cli.py,sha256=_92FvDV4qza0nSmyiXqacYxyo1gY6IPwD4gCm6kZfqI,33213
|
7
7
|
janus/converter/__init__.py,sha256=U2EOMcCykiC0ZqhorNefOP_04hOF18qhYoPKrVp1Vrk,345
|
8
8
|
janus/converter/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
9
|
janus/converter/_tests/test_translate.py,sha256=yzcFEGc_z8QmBBBmC9dZnfL9tT8CD1rkpc8Hz44Jp4c,5631
|
10
|
-
janus/converter/converter.py,sha256=
|
10
|
+
janus/converter/converter.py,sha256=CgaNjQE6bz8qCpoYTEwuON40clOq-8r-lp2C173xS0E,27422
|
11
11
|
janus/converter/diagram.py,sha256=5mo1H3Y1uIBPYdIsWz9kxluN5DNyuUMZrtcJmGF2Uw0,5335
|
12
12
|
janus/converter/document.py,sha256=hsW512veNjFWbdl5WriuUdNmMEqZy8ktRvqn9rRmA6E,4566
|
13
13
|
janus/converter/evaluate.py,sha256=APWQUY3gjAXqkJkPzvj0UA4wPK3Cv9QSJLM-YK9t-ng,476
|
14
|
-
janus/converter/requirements.py,sha256=
|
14
|
+
janus/converter/requirements.py,sha256=9tvQ40FZJtG8niIFn45gPQCgKKHVPPoFLinBv6RAqO4,2027
|
15
15
|
janus/converter/translate.py,sha256=0brQTlSfBYmXtoM8QYIOiyr0LrTr0S1n68Du-BR7_WQ,4236
|
16
16
|
janus/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
17
|
janus/embedding/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -29,13 +29,13 @@ janus/language/_tests/test_splitter.py,sha256=Hqexa39LLEXlK3ZUw7Zot4PUIACvye2vkq
|
|
29
29
|
janus/language/alc/__init__.py,sha256=j7vOMGhT1Vri6p8dsjSaY-fkO5uFn0sJ0nrNGGvcizM,42
|
30
30
|
janus/language/alc/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
31
31
|
janus/language/alc/_tests/test_alc.py,sha256=NgVeOctm9zf-S328DdUNn9et_-lK1t5O0O2FKElb91Q,1027
|
32
|
-
janus/language/alc/alc.py,sha256=
|
32
|
+
janus/language/alc/alc.py,sha256=l1p6zwyE7ZzY9rnjsUZGuQW41hijFrzxnXnFOpfGq8k,6590
|
33
33
|
janus/language/binary/__init__.py,sha256=AlNAe12ZA366kcGSrQ1FJyOdbwxFqGBFkYR2K6yL818,51
|
34
34
|
janus/language/binary/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
35
35
|
janus/language/binary/_tests/test_binary.py,sha256=SDdI6tsQj9yXle7wBsksHuKULLMHv7mNgUkDx1nCvpw,1733
|
36
36
|
janus/language/binary/binary.py,sha256=jcc-LZx8Ss-g4j0a691U2yVJV6JGV_zpE_6y_aDq4Cw,6579
|
37
37
|
janus/language/binary/reveng/decompile_script.py,sha256=veW51oJzuO-4UD3Er062jXZ_FYtTFo9OCkl82Z2xr6A,2182
|
38
|
-
janus/language/block.py,sha256=
|
38
|
+
janus/language/block.py,sha256=2rjAYUosHFfWRgLnzf50uAgTMST4Md9Kx6JrlUfEfX4,9398
|
39
39
|
janus/language/combine.py,sha256=Wtve06fa-_Wjv_V5RIf1Nfmg0UxcOEtFNj4vVHpSNbo,2940
|
40
40
|
janus/language/file.py,sha256=jy-cReAoI6F97TXR5bbhPyt8XyUZCdFYnVboubDA_y4,571
|
41
41
|
janus/language/mumps/__init__.py,sha256=-Ou_wJ-JgHezfp1dub2_qCYNiK9wO-zo2MlqxM9qiwE,48
|
@@ -47,7 +47,7 @@ janus/language/naive/__init__.py,sha256=_Gq4inONyVYxe8WLB59d_69kqGbtF40BGKoJPnK4
|
|
47
47
|
janus/language/naive/basic_splitter.py,sha256=RM9pJK2YkHfb6_EFEV-dh_rLqkjS6v0cn3ASPf8A6Fg,459
|
48
48
|
janus/language/naive/chunk_splitter.py,sha256=ebRSbaJhDW-Hyr5__ukbdmAl6kQ1WWFqrq_SfCgHo6k,772
|
49
49
|
janus/language/naive/registry.py,sha256=8YQX1q0IdAm7t69-oC_00I-vfkdRnHuX-OD3KEjEIuU,294
|
50
|
-
janus/language/naive/simple_ast.py,sha256=
|
50
|
+
janus/language/naive/simple_ast.py,sha256=T53UwAZyRfePXzpiNUhe4FyDev6YcX1dVDxYkTcDRPE,3032
|
51
51
|
janus/language/naive/tag_splitter.py,sha256=IXWMn9tBVUGAtzvQi89GhoZ6g7fPXk5MzO0kMCr2mb0,2045
|
52
52
|
janus/language/node.py,sha256=baoYFtapwBQqBtUN6EvHFYRkbR-EcEw1b3fQvH9zIAM,204
|
53
53
|
janus/language/splitter.py,sha256=hITFp4a9bJ6sP74AWvC2GWa2Poo10MyHXYTj9hviXss,16970
|
@@ -87,12 +87,12 @@ janus/parsers/_tests/test_code_parser.py,sha256=RVgMmLvg8_57g0uJphfX-jZZsyBqOSuG
|
|
87
87
|
janus/parsers/code_parser.py,sha256=SZBsYThG4iszKlu4fHoWrs-6cbJiUFjWv4cLSr5bzDM,1790
|
88
88
|
janus/parsers/doc_parser.py,sha256=bJiOE5M7npUZur_1MWJ14C2HZl7-yXExqRXiC5ZBJvI,5679
|
89
89
|
janus/parsers/eval_parser.py,sha256=L1Lu2aNimcqUshe0FQee_9Zqj1rrqyZPXCgEAS05VJ4,2740
|
90
|
-
janus/parsers/refiner_parser.py,sha256=
|
90
|
+
janus/parsers/refiner_parser.py,sha256=5zGoPZyttfRw3kXYDKHId2nhVvAcv1QsvpDFRpe-Few,1680
|
91
91
|
janus/parsers/reqs_parser.py,sha256=6YzpF63rjuDPqpKWfYvtjpsluWQ-UboWlsKoGrGQogA,2380
|
92
92
|
janus/parsers/uml.py,sha256=ZRyGY8YxvYibacTd-WZEAAaW3XjmvJhPJE3o29f71t8,1825
|
93
93
|
janus/prompts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
94
94
|
janus/prompts/prompt.py,sha256=3796YXIzzIec9b0iUzd8VZlq-AdQbzq8qUGXLy4KH-0,10586
|
95
|
-
janus/refiners/refiner.py,sha256=
|
95
|
+
janus/refiners/refiner.py,sha256=GkV4oUSCrLAhyDJY2aY_Jt8PRF3sC6-bv58YbL2PaNk,2227
|
96
96
|
janus/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
97
97
|
janus/utils/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
98
98
|
janus/utils/_tests/test_logger.py,sha256=jkkvrCTKwsFCsZtmyuvc-WJ0rC7LJi2Z91sIe4IiKzA,2209
|
@@ -100,8 +100,8 @@ janus/utils/_tests/test_progress.py,sha256=Rs_u5PiGjP-L-o6C1fhwfE1ig8jYu9Xo9s4p8
|
|
100
100
|
janus/utils/enums.py,sha256=AoilbdiYyMvY2Mp0AM4xlbLSELfut2XMwhIM1S_msP4,27610
|
101
101
|
janus/utils/logger.py,sha256=KZeuaMAnlSZCsj4yL0P6N-JzZwpxXygzACWfdZFeuek,2337
|
102
102
|
janus/utils/progress.py,sha256=PIpcQec7SrhsfqB25LHj2CDDkfm9umZx90d9LZnAx6k,1469
|
103
|
-
janus_llm-3.
|
104
|
-
janus_llm-3.
|
105
|
-
janus_llm-3.
|
106
|
-
janus_llm-3.
|
107
|
-
janus_llm-3.3.2.dist-info/RECORD,,
|
103
|
+
janus_llm-3.4.1.dist-info/LICENSE,sha256=_j0st0a-HB6MRbP3_BW3PUqpS16v54luyy-1zVyl8NU,10789
|
104
|
+
janus_llm-3.4.1.dist-info/METADATA,sha256=lQxVgmdepmQHUey6HX2z0wTUEg6XqpaCu91kWpbeq_E,4184
|
105
|
+
janus_llm-3.4.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
106
|
+
janus_llm-3.4.1.dist-info/entry_points.txt,sha256=OGhQwzj6pvXp79B0SaBD5apGekCu7Dwe9fZZT_TZ544,39
|
107
|
+
janus_llm-3.4.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|