janus-llm 4.1.0__py3-none-any.whl → 4.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli.py +136 -25
- janus/converter/__init__.py +1 -0
- janus/converter/converter.py +45 -47
- janus/converter/partition.py +27 -0
- janus/language/combine.py +22 -0
- janus/llm/models_info.py +3 -0
- janus/parsers/partition_parser.py +136 -0
- janus/refiners/refiner.py +8 -12
- janus/refiners/uml.py +33 -0
- janus/retrievers/retriever.py +60 -0
- janus/utils/pdf_docs_reader.py +134 -0
- {janus_llm-4.1.0.dist-info → janus_llm-4.2.0.dist-info}/METADATA +9 -1
- {janus_llm-4.1.0.dist-info → janus_llm-4.2.0.dist-info}/RECORD +17 -13
- {janus_llm-4.1.0.dist-info → janus_llm-4.2.0.dist-info}/WHEEL +1 -1
- {janus_llm-4.1.0.dist-info → janus_llm-4.2.0.dist-info}/LICENSE +0 -0
- {janus_llm-4.1.0.dist-info → janus_llm-4.2.0.dist-info}/entry_points.txt +0 -0
janus/__init__.py
CHANGED
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
|
|
5
5
|
from janus.converter.translate import Translator
|
6
6
|
from janus.metrics import * # noqa: F403
|
7
7
|
|
8
|
-
__version__ = "4.
|
8
|
+
__version__ = "4.2.0"
|
9
9
|
|
10
10
|
# Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
|
11
11
|
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
|
janus/cli.py
CHANGED
@@ -13,10 +13,13 @@ from rich.console import Console
|
|
13
13
|
from rich.prompt import Confirm
|
14
14
|
from typing_extensions import Annotated
|
15
15
|
|
16
|
+
import janus.refiners.refiner
|
17
|
+
import janus.refiners.uml
|
16
18
|
from janus.converter.aggregator import Aggregator
|
17
19
|
from janus.converter.converter import Converter
|
18
20
|
from janus.converter.diagram import DiagramGenerator
|
19
21
|
from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
|
22
|
+
from janus.converter.partition import Partitioner
|
20
23
|
from janus.converter.requirements import RequirementsDocumenter
|
21
24
|
from janus.converter.translate import Translator
|
22
25
|
from janus.embedding.collections import Collections
|
@@ -44,7 +47,6 @@ from janus.llm.models_info import (
|
|
44
47
|
openai_models,
|
45
48
|
)
|
46
49
|
from janus.metrics.cli import evaluate
|
47
|
-
from janus.refiners.refiner import REFINERS
|
48
50
|
from janus.utils.enums import LANGUAGES
|
49
51
|
from janus.utils.logger import create_logger
|
50
52
|
|
@@ -69,6 +71,18 @@ with open(db_file, "r") as f:
|
|
69
71
|
collections_config_file = Path(db_loc) / "collections.json"
|
70
72
|
|
71
73
|
|
74
|
+
def get_subclasses(cls):
|
75
|
+
return set(cls.__subclasses__()).union(
|
76
|
+
set(s for c in cls.__subclasses__() for s in get_subclasses(c))
|
77
|
+
)
|
78
|
+
|
79
|
+
|
80
|
+
REFINER_TYPES = get_subclasses(janus.refiners.refiner.JanusRefiner).union(
|
81
|
+
{janus.refiners.refiner.JanusRefiner}
|
82
|
+
)
|
83
|
+
REFINERS = {r.__name__: r for r in REFINER_TYPES}
|
84
|
+
|
85
|
+
|
72
86
|
def get_collections_config():
|
73
87
|
if collections_config_file.exists():
|
74
88
|
with open(collections_config_file, "r") as f:
|
@@ -244,22 +258,23 @@ def translate(
|
|
244
258
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
245
259
|
),
|
246
260
|
] = "file",
|
247
|
-
|
248
|
-
str,
|
261
|
+
refiner_types: Annotated[
|
262
|
+
list[str],
|
249
263
|
typer.Option(
|
250
264
|
"-r",
|
251
265
|
"--refiner",
|
252
|
-
help="
|
266
|
+
help="List of refiner types to use. Add -r for each refiner to use in\
|
267
|
+
refinement chain",
|
253
268
|
click_type=click.Choice(list(REFINERS.keys())),
|
254
269
|
),
|
255
|
-
] = "
|
270
|
+
] = ["JanusRefiner"],
|
256
271
|
retriever_type: Annotated[
|
257
272
|
str,
|
258
273
|
typer.Option(
|
259
274
|
"-R",
|
260
275
|
"--retriever",
|
261
276
|
help="Name of custom retriever to use",
|
262
|
-
click_type=click.Choice(["active_usings"]),
|
277
|
+
click_type=click.Choice(["active_usings", "language_docs"]),
|
263
278
|
),
|
264
279
|
] = None,
|
265
280
|
max_tokens: Annotated[
|
@@ -272,6 +287,7 @@ def translate(
|
|
272
287
|
),
|
273
288
|
] = None,
|
274
289
|
):
|
290
|
+
refiner_types = [REFINERS[r] for r in refiner_types]
|
275
291
|
try:
|
276
292
|
target_language, target_version = target_lang.split("-")
|
277
293
|
except ValueError:
|
@@ -296,7 +312,7 @@ def translate(
|
|
296
312
|
db_path=db_loc,
|
297
313
|
db_config=collections_config,
|
298
314
|
splitter_type=splitter_type,
|
299
|
-
|
315
|
+
refiner_types=refiner_types,
|
300
316
|
retriever_type=retriever_type,
|
301
317
|
)
|
302
318
|
translator.translate(input_dir, output_dir, overwrite, collection)
|
@@ -402,22 +418,23 @@ def document(
|
|
402
418
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
403
419
|
),
|
404
420
|
] = "file",
|
405
|
-
|
406
|
-
str,
|
421
|
+
refiner_types: Annotated[
|
422
|
+
list[str],
|
407
423
|
typer.Option(
|
408
424
|
"-r",
|
409
425
|
"--refiner",
|
410
|
-
help="
|
426
|
+
help="List of refiner types to use. Add -r for each refiner to use in\
|
427
|
+
refinement chain",
|
411
428
|
click_type=click.Choice(list(REFINERS.keys())),
|
412
429
|
),
|
413
|
-
] = "
|
430
|
+
] = ["JanusRefiner"],
|
414
431
|
retriever_type: Annotated[
|
415
432
|
str,
|
416
433
|
typer.Option(
|
417
434
|
"-R",
|
418
435
|
"--retriever",
|
419
436
|
help="Name of custom retriever to use",
|
420
|
-
click_type=click.Choice(["active_usings"]),
|
437
|
+
click_type=click.Choice(["active_usings", "language_docs"]),
|
421
438
|
),
|
422
439
|
] = None,
|
423
440
|
max_tokens: Annotated[
|
@@ -430,6 +447,7 @@ def document(
|
|
430
447
|
),
|
431
448
|
] = None,
|
432
449
|
):
|
450
|
+
refiner_types = [REFINERS[r] for r in refiner_types]
|
433
451
|
model_arguments = dict(temperature=temperature)
|
434
452
|
collections_config = get_collections_config()
|
435
453
|
kwargs = dict(
|
@@ -441,7 +459,7 @@ def document(
|
|
441
459
|
db_path=db_loc,
|
442
460
|
db_config=collections_config,
|
443
461
|
splitter_type=splitter_type,
|
444
|
-
|
462
|
+
refiner_types=refiner_types,
|
445
463
|
retriever_type=retriever_type,
|
446
464
|
)
|
447
465
|
if doc_mode == "madlibs":
|
@@ -458,12 +476,6 @@ def document(
|
|
458
476
|
documenter.translate(input_dir, output_dir, overwrite, collection)
|
459
477
|
|
460
478
|
|
461
|
-
def get_subclasses(cls):
|
462
|
-
return set(cls.__subclasses__()).union(
|
463
|
-
set(s for c in cls.__subclasses__() for s in get_subclasses(c))
|
464
|
-
)
|
465
|
-
|
466
|
-
|
467
479
|
@app.command()
|
468
480
|
def aggregate(
|
469
481
|
input_dir: Annotated[
|
@@ -578,6 +590,103 @@ def aggregate(
|
|
578
590
|
aggregator.translate(input_dir, output_dir, overwrite, collection)
|
579
591
|
|
580
592
|
|
593
|
+
@app.command(
|
594
|
+
help="Partition input code using an LLM.",
|
595
|
+
no_args_is_help=True,
|
596
|
+
)
|
597
|
+
def partition(
|
598
|
+
input_dir: Annotated[
|
599
|
+
Path,
|
600
|
+
typer.Option(
|
601
|
+
"--input",
|
602
|
+
"-i",
|
603
|
+
help="The directory containing the source code to be partitioned. ",
|
604
|
+
),
|
605
|
+
],
|
606
|
+
language: Annotated[
|
607
|
+
str,
|
608
|
+
typer.Option(
|
609
|
+
"--language",
|
610
|
+
"-l",
|
611
|
+
help="The language of the source code.",
|
612
|
+
click_type=click.Choice(sorted(LANGUAGES)),
|
613
|
+
),
|
614
|
+
],
|
615
|
+
output_dir: Annotated[
|
616
|
+
Path,
|
617
|
+
typer.Option(
|
618
|
+
"--output-dir", "-o", help="The directory to store the partitioned code in."
|
619
|
+
),
|
620
|
+
],
|
621
|
+
llm_name: Annotated[
|
622
|
+
str,
|
623
|
+
typer.Option(
|
624
|
+
"--llm",
|
625
|
+
"-L",
|
626
|
+
help="The custom name of the model set with 'janus llm add'.",
|
627
|
+
),
|
628
|
+
] = "gpt-4o",
|
629
|
+
max_prompts: Annotated[
|
630
|
+
int,
|
631
|
+
typer.Option(
|
632
|
+
"--max-prompts",
|
633
|
+
"-m",
|
634
|
+
help="The maximum number of times to prompt a model on one functional block "
|
635
|
+
"before exiting the application. This is to prevent wasting too much money.",
|
636
|
+
),
|
637
|
+
] = 10,
|
638
|
+
overwrite: Annotated[
|
639
|
+
bool,
|
640
|
+
typer.Option(
|
641
|
+
"--overwrite/--preserve",
|
642
|
+
help="Whether to overwrite existing files in the output directory",
|
643
|
+
),
|
644
|
+
] = False,
|
645
|
+
temperature: Annotated[
|
646
|
+
float,
|
647
|
+
typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
|
648
|
+
] = 0.7,
|
649
|
+
splitter_type: Annotated[
|
650
|
+
str,
|
651
|
+
typer.Option(
|
652
|
+
"-S",
|
653
|
+
"--splitter",
|
654
|
+
help="Name of custom splitter to use",
|
655
|
+
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
656
|
+
),
|
657
|
+
] = "file",
|
658
|
+
max_tokens: Annotated[
|
659
|
+
int,
|
660
|
+
typer.Option(
|
661
|
+
"--max-tokens",
|
662
|
+
"-M",
|
663
|
+
help="The maximum number of tokens the model will take in. "
|
664
|
+
"If unspecificed, model's default max will be used.",
|
665
|
+
),
|
666
|
+
] = None,
|
667
|
+
partition_token_limit: Annotated[
|
668
|
+
int,
|
669
|
+
typer.Option(
|
670
|
+
"--partition-tokens",
|
671
|
+
"-pt",
|
672
|
+
help="The limit on the number of tokens per partition.",
|
673
|
+
),
|
674
|
+
] = 8192,
|
675
|
+
):
|
676
|
+
model_arguments = dict(temperature=temperature)
|
677
|
+
kwargs = dict(
|
678
|
+
model=llm_name,
|
679
|
+
model_arguments=model_arguments,
|
680
|
+
source_language=language,
|
681
|
+
max_prompts=max_prompts,
|
682
|
+
max_tokens=max_tokens,
|
683
|
+
splitter_type=splitter_type,
|
684
|
+
partition_token_limit=partition_token_limit,
|
685
|
+
)
|
686
|
+
partitioner = Partitioner(**kwargs)
|
687
|
+
partitioner.translate(input_dir, output_dir, overwrite)
|
688
|
+
|
689
|
+
|
581
690
|
@app.command(
|
582
691
|
help="Diagram input code using an LLM.",
|
583
692
|
no_args_is_help=True,
|
@@ -667,25 +776,27 @@ def diagram(
|
|
667
776
|
click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
|
668
777
|
),
|
669
778
|
] = "file",
|
670
|
-
|
671
|
-
str,
|
779
|
+
refiner_types: Annotated[
|
780
|
+
list[str],
|
672
781
|
typer.Option(
|
673
782
|
"-r",
|
674
783
|
"--refiner",
|
675
|
-
help="
|
784
|
+
help="List of refiner types to use. Add -r for each refiner to use in\
|
785
|
+
refinement chain",
|
676
786
|
click_type=click.Choice(list(REFINERS.keys())),
|
677
787
|
),
|
678
|
-
] = "
|
788
|
+
] = ["JanusRefiner"],
|
679
789
|
retriever_type: Annotated[
|
680
790
|
str,
|
681
791
|
typer.Option(
|
682
792
|
"-R",
|
683
793
|
"--retriever",
|
684
794
|
help="Name of custom retriever to use",
|
685
|
-
click_type=click.Choice(["active_usings"]),
|
795
|
+
click_type=click.Choice(["active_usings", "language_docs"]),
|
686
796
|
),
|
687
797
|
] = None,
|
688
798
|
):
|
799
|
+
refiner_types = [REFINERS[r] for r in refiner_types]
|
689
800
|
model_arguments = dict(temperature=temperature)
|
690
801
|
collections_config = get_collections_config()
|
691
802
|
diagram_generator = DiagramGenerator(
|
@@ -696,7 +807,7 @@ def diagram(
|
|
696
807
|
db_path=db_loc,
|
697
808
|
db_config=collections_config,
|
698
809
|
splitter_type=splitter_type,
|
699
|
-
|
810
|
+
refiner_types=refiner_types,
|
700
811
|
retriever_type=retriever_type,
|
701
812
|
diagram_type=diagram_type,
|
702
813
|
add_documentation=add_documentation,
|
janus/converter/__init__.py
CHANGED
@@ -2,5 +2,6 @@ from janus.converter.converter import Converter
|
|
2
2
|
from janus.converter.diagram import DiagramGenerator
|
3
3
|
from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
|
4
4
|
from janus.converter.evaluate import Evaluator
|
5
|
+
from janus.converter.partition import Partitioner
|
5
6
|
from janus.converter.requirements import RequirementsDocumenter
|
6
7
|
from janus.converter.translate import Translator
|
janus/converter/converter.py
CHANGED
@@ -6,7 +6,12 @@ from typing import Any
|
|
6
6
|
|
7
7
|
from langchain_core.exceptions import OutputParserException
|
8
8
|
from langchain_core.prompts import ChatPromptTemplate
|
9
|
-
from langchain_core.runnables import
|
9
|
+
from langchain_core.runnables import (
|
10
|
+
Runnable,
|
11
|
+
RunnableLambda,
|
12
|
+
RunnableParallel,
|
13
|
+
RunnablePassthrough,
|
14
|
+
)
|
10
15
|
from openai import BadRequestError, RateLimitError
|
11
16
|
from pydantic import ValidationError
|
12
17
|
|
@@ -23,15 +28,14 @@ from janus.language.splitter import (
|
|
23
28
|
from janus.llm.model_callbacks import get_model_callback
|
24
29
|
from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel, load_model
|
25
30
|
from janus.parsers.parser import GenericParser, JanusParser
|
26
|
-
from janus.refiners.refiner import
|
27
|
-
FixParserExceptions,
|
28
|
-
HallucinationRefiner,
|
29
|
-
JanusRefiner,
|
30
|
-
ReflectionRefiner,
|
31
|
-
)
|
31
|
+
from janus.refiners.refiner import JanusRefiner
|
32
32
|
|
33
33
|
# from janus.refiners.refiner import BasicRefiner, Refiner
|
34
|
-
from janus.retrievers.retriever import
|
34
|
+
from janus.retrievers.retriever import (
|
35
|
+
ActiveUsingsRetriever,
|
36
|
+
JanusRetriever,
|
37
|
+
LanguageDocsRetriever,
|
38
|
+
)
|
35
39
|
from janus.utils.enums import LANGUAGES
|
36
40
|
from janus.utils.logger import create_logger
|
37
41
|
|
@@ -78,7 +82,7 @@ class Converter:
|
|
78
82
|
protected_node_types: tuple[str, ...] = (),
|
79
83
|
prune_node_types: tuple[str, ...] = (),
|
80
84
|
splitter_type: str = "file",
|
81
|
-
|
85
|
+
refiner_types: list[type[JanusRefiner]] = [JanusRefiner],
|
82
86
|
retriever_type: str | None = None,
|
83
87
|
) -> None:
|
84
88
|
"""Initialize a Converter instance.
|
@@ -105,6 +109,7 @@ class Converter:
|
|
105
109
|
- None
|
106
110
|
retriever_type: The type of retriever to use. Valid values:
|
107
111
|
- "active_usings"
|
112
|
+
- "language_docs"
|
108
113
|
- None
|
109
114
|
"""
|
110
115
|
self._changed_attrs: set = set()
|
@@ -133,10 +138,11 @@ class Converter:
|
|
133
138
|
self._prompt: ChatPromptTemplate
|
134
139
|
|
135
140
|
self._parser: JanusParser = GenericParser()
|
141
|
+
self._base_parser: JanusParser = GenericParser()
|
136
142
|
self._combiner: Combiner = Combiner()
|
137
143
|
|
138
144
|
self._splitter_type: str
|
139
|
-
self.
|
145
|
+
self._refiner_types: list[type[JanusRefiner]]
|
140
146
|
self._retriever_type: str | None
|
141
147
|
|
142
148
|
self._splitter: Splitter
|
@@ -144,7 +150,7 @@ class Converter:
|
|
144
150
|
self._retriever: JanusRetriever
|
145
151
|
|
146
152
|
self.set_splitter(splitter_type=splitter_type)
|
147
|
-
self.
|
153
|
+
self.set_refiner_types(refiner_types=refiner_types)
|
148
154
|
self.set_retriever(retriever_type=retriever_type)
|
149
155
|
self.set_model(model_name=model, **model_arguments)
|
150
156
|
self.set_prompt(prompt_template=prompt_template)
|
@@ -170,7 +176,7 @@ class Converter:
|
|
170
176
|
self._load_model()
|
171
177
|
self._load_prompt()
|
172
178
|
self._load_retriever()
|
173
|
-
self.
|
179
|
+
self._load_refiner_chain()
|
174
180
|
self._load_splitter()
|
175
181
|
self._load_vectorizer()
|
176
182
|
self._load_chain()
|
@@ -210,13 +216,13 @@ class Converter:
|
|
210
216
|
|
211
217
|
self._splitter_type = splitter_type
|
212
218
|
|
213
|
-
def
|
219
|
+
def set_refiner_types(self, refiner_types: list[type[JanusRefiner]]) -> None:
|
214
220
|
"""Validate and set the refiner type
|
215
221
|
|
216
222
|
Arguments:
|
217
223
|
refiner_type: the type of refiner to use
|
218
224
|
"""
|
219
|
-
self.
|
225
|
+
self._refiner_types = refiner_types
|
220
226
|
|
221
227
|
def set_retriever(self, retriever_type: str | None) -> None:
|
222
228
|
"""Validate and set the retriever type
|
@@ -355,48 +361,40 @@ class Converter:
|
|
355
361
|
def _load_retriever(self):
|
356
362
|
if self._retriever_type == "active_usings":
|
357
363
|
self._retriever = ActiveUsingsRetriever()
|
364
|
+
elif self._retriever_type == "language_docs":
|
365
|
+
self._retriever = LanguageDocsRetriever(self._llm, self._source_language)
|
358
366
|
else:
|
359
367
|
self._retriever = JanusRetriever()
|
360
368
|
|
361
|
-
@run_if_changed("
|
362
|
-
def
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
max_retries=self.max_prompts,
|
369
|
+
@run_if_changed("_refiner_types", "_model_name", "max_prompts", "_parser")
|
370
|
+
def _load_refiner_chain(self) -> None:
|
371
|
+
self._refiner_chain = RunnableParallel(
|
372
|
+
completion=self._llm,
|
373
|
+
prompt_value=RunnablePassthrough(),
|
374
|
+
)
|
375
|
+
for refiner_type in self._refiner_types[:-1]:
|
376
|
+
# NOTE: Do NOT remove refiner_type=refiner_type from lambda.
|
377
|
+
# Due to lambda capture, must be present or chain will not
|
378
|
+
# be correctly constructed.
|
379
|
+
self._refiner_chain = self._refiner_chain | RunnableParallel(
|
380
|
+
completion=lambda x, refiner_type=refiner_type: refiner_type(
|
381
|
+
llm=self._llm,
|
382
|
+
parser=self._base_parser,
|
383
|
+
max_retries=self.max_prompts,
|
384
|
+
).parse_completion(**x),
|
385
|
+
prompt_value=lambda x: x["prompt_value"],
|
379
386
|
)
|
380
|
-
|
381
|
-
self.
|
387
|
+
self._refiner_chain = self._refiner_chain | RunnableLambda(
|
388
|
+
lambda x: self._refiner_types[-1](
|
382
389
|
llm=self._llm,
|
383
390
|
parser=self._parser,
|
384
391
|
max_retries=self.max_prompts,
|
385
|
-
)
|
386
|
-
|
387
|
-
self._refiner = JanusRefiner(parser=self._parser)
|
392
|
+
).parse_completion(**x)
|
393
|
+
)
|
388
394
|
|
389
|
-
@run_if_changed("_parser", "_retriever", "_prompt", "_llm", "
|
395
|
+
@run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner_chain")
|
390
396
|
def _load_chain(self):
|
391
|
-
self.chain = (
|
392
|
-
self._input_runnable()
|
393
|
-
| self._prompt
|
394
|
-
| RunnableParallel(
|
395
|
-
completion=self._llm,
|
396
|
-
prompt_value=RunnablePassthrough(),
|
397
|
-
)
|
398
|
-
| self._refiner.parse_runnable
|
399
|
-
)
|
397
|
+
self.chain = self._input_runnable() | self._prompt | self._refiner_chain
|
400
398
|
|
401
399
|
def _input_runnable(self) -> Runnable:
|
402
400
|
return RunnableParallel(
|
@@ -0,0 +1,27 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
|
3
|
+
from janus.converter.converter import Converter
|
4
|
+
from janus.language.block import TranslatedCodeBlock
|
5
|
+
from janus.parsers.partition_parser import PartitionParser
|
6
|
+
from janus.utils.logger import create_logger
|
7
|
+
|
8
|
+
log = create_logger(__name__)
|
9
|
+
|
10
|
+
|
11
|
+
class Partitioner(Converter):
|
12
|
+
def __init__(self, partition_token_limit: int, **kwargs):
|
13
|
+
super().__init__(**kwargs)
|
14
|
+
self.set_prompt("partition")
|
15
|
+
self._load_model()
|
16
|
+
self._parser = PartitionParser(
|
17
|
+
token_limit=partition_token_limit,
|
18
|
+
model=self._llm,
|
19
|
+
)
|
20
|
+
self._target_language = self._source_language
|
21
|
+
self._target_suffix = self._source_suffix
|
22
|
+
self._load_parameters()
|
23
|
+
|
24
|
+
def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
|
25
|
+
output_str = self._parser.parse_combined_output(block.complete_text)
|
26
|
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
27
|
+
out_path.write_text(output_str, encoding="utf-8")
|
janus/language/combine.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
import re
|
2
|
+
|
1
3
|
from janus.language.block import CodeBlock, TranslatedCodeBlock
|
2
4
|
from janus.language.file import FileManager
|
3
5
|
from janus.utils.logger import create_logger
|
@@ -90,3 +92,23 @@ class ChunkCombiner(Combiner):
|
|
90
92
|
root: The functional code block to combine with its children.
|
91
93
|
"""
|
92
94
|
return root
|
95
|
+
|
96
|
+
|
97
|
+
class PartitionCombiner(Combiner):
|
98
|
+
@staticmethod
|
99
|
+
def combine(root: CodeBlock) -> None:
|
100
|
+
"""A combiner which inserts partition tags between code blocks"""
|
101
|
+
queue = [root]
|
102
|
+
while queue:
|
103
|
+
block = queue.pop(0)
|
104
|
+
if block.children:
|
105
|
+
queue.extend(block.children)
|
106
|
+
else:
|
107
|
+
block.affixes = (block.prefix, block.suffix + "\n<JANUS_PARTITION>\n")
|
108
|
+
|
109
|
+
super(PartitionCombiner, PartitionCombiner).combine(root)
|
110
|
+
root.text = re.sub(r"(?:\n<JANUS_PARTITION>\n)+$", "", root.text)
|
111
|
+
root.affixes = (
|
112
|
+
root.prefix,
|
113
|
+
re.sub(r"(?:\n<JANUS_PARTITION>\n)+$", "", root.suffix),
|
114
|
+
)
|
janus/llm/models_info.py
CHANGED
@@ -90,6 +90,7 @@ claude_models = [
|
|
90
90
|
"bedrock-claude-instant-v1",
|
91
91
|
"bedrock-claude-haiku",
|
92
92
|
"bedrock-claude-sonnet",
|
93
|
+
"bedrock-claude-sonnet-3.5",
|
93
94
|
]
|
94
95
|
llama2_models = [
|
95
96
|
"bedrock-llama2-70b",
|
@@ -153,6 +154,7 @@ MODEL_ID_TO_LONG_ID = {
|
|
153
154
|
"bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
|
154
155
|
"bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
|
155
156
|
"bedrock-claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
|
157
|
+
"bedrock-claude-sonnet-3.5": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
156
158
|
"bedrock-llama2-70b": "meta.llama2-70b-v1",
|
157
159
|
"bedrock-llama2-70b-chat": "meta.llama2-70b-chat-v1",
|
158
160
|
"bedrock-llama2-13b": "meta.llama2-13b-chat-v1",
|
@@ -200,6 +202,7 @@ TOKEN_LIMITS: dict[str, int] = {
|
|
200
202
|
"anthropic.claude-instant-v1": 100_000,
|
201
203
|
"anthropic.claude-3-haiku-20240307-v1:0": 248_000,
|
202
204
|
"anthropic.claude-3-sonnet-20240229-v1:0": 248_000,
|
205
|
+
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200_000,
|
203
206
|
"meta.llama2-70b-v1": 4096,
|
204
207
|
"meta.llama2-70b-chat-v1": 4096,
|
205
208
|
"meta.llama2-13b-chat-v1": 4096,
|
@@ -0,0 +1,136 @@
|
|
1
|
+
import json
|
2
|
+
import random
|
3
|
+
import uuid
|
4
|
+
|
5
|
+
from langchain.output_parsers import PydanticOutputParser
|
6
|
+
from langchain_core.exceptions import OutputParserException
|
7
|
+
from langchain_core.language_models import BaseLanguageModel
|
8
|
+
from langchain_core.messages import BaseMessage
|
9
|
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
10
|
+
|
11
|
+
from janus.language.block import CodeBlock
|
12
|
+
from janus.parsers.parser import JanusParser
|
13
|
+
from janus.utils.logger import create_logger
|
14
|
+
|
15
|
+
log = create_logger(__name__)
|
16
|
+
RNG = random.Random()
|
17
|
+
|
18
|
+
|
19
|
+
class PartitionObject(BaseModel):
|
20
|
+
reasoning: str = Field(
|
21
|
+
description="An explanation for why the code should be split at this point"
|
22
|
+
)
|
23
|
+
location: str = Field(
|
24
|
+
description="The 8-character line label which should start a new chunk"
|
25
|
+
)
|
26
|
+
|
27
|
+
|
28
|
+
class PartitionList(BaseModel):
|
29
|
+
__root__: list[PartitionObject] = Field(
|
30
|
+
description=(
|
31
|
+
"A list of appropriate split points, each with a `reasoning` field "
|
32
|
+
"that explains a justification for splitting the code at that point, "
|
33
|
+
"and a `location` field which is simply the 8-character line ID. "
|
34
|
+
"The `reasoning` field should always be included first."
|
35
|
+
)
|
36
|
+
)
|
37
|
+
|
38
|
+
|
39
|
+
class PartitionParser(JanusParser, PydanticOutputParser):
|
40
|
+
token_limit: int
|
41
|
+
model: BaseLanguageModel
|
42
|
+
lines: list[str] = []
|
43
|
+
line_id_to_index: dict[str, int] = {}
|
44
|
+
|
45
|
+
def __init__(self, token_limit: int, model: BaseLanguageModel):
|
46
|
+
PydanticOutputParser.__init__(
|
47
|
+
self,
|
48
|
+
pydantic_object=PartitionList,
|
49
|
+
model=model,
|
50
|
+
token_limit=token_limit,
|
51
|
+
)
|
52
|
+
|
53
|
+
def parse_input(self, block: CodeBlock) -> str:
|
54
|
+
code = str(block.text)
|
55
|
+
RNG.seed(code)
|
56
|
+
|
57
|
+
self.lines = code.split("\n")
|
58
|
+
|
59
|
+
# Generate a unique ID for each line (ensure they are unique)
|
60
|
+
line_ids = set()
|
61
|
+
while len(line_ids) < len(self.lines):
|
62
|
+
line_ids.add(str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8])
|
63
|
+
|
64
|
+
# Prepend each line with the corresponding ID, save the mapping
|
65
|
+
self.line_id_to_index = {lid: i for i, lid in enumerate(line_ids)}
|
66
|
+
processed = "\n".join(
|
67
|
+
f"{line_id}\t{self.lines[i]}" for line_id, i in self.line_id_to_index.items()
|
68
|
+
)
|
69
|
+
return processed
|
70
|
+
|
71
|
+
def parse(self, text: str | BaseMessage) -> str:
|
72
|
+
if isinstance(text, BaseMessage):
|
73
|
+
text = str(text.content)
|
74
|
+
|
75
|
+
try:
|
76
|
+
out: PartitionList = super().parse(text)
|
77
|
+
except (OutputParserException, json.JSONDecodeError):
|
78
|
+
log.debug(f"Invalid JSON object. Output:\n{text}")
|
79
|
+
raise
|
80
|
+
|
81
|
+
# Locate any invalid line IDs, raise exception if any found
|
82
|
+
invalid_splits = [
|
83
|
+
partition.location
|
84
|
+
for partition in out.__root__
|
85
|
+
if partition.location not in self.line_id_to_index
|
86
|
+
]
|
87
|
+
if invalid_splits:
|
88
|
+
err_msg = (
|
89
|
+
f"{len(invalid_splits)} line ID(s) not found in input: "
|
90
|
+
+ ", ".join(invalid_splits)
|
91
|
+
)
|
92
|
+
log.warning(err_msg)
|
93
|
+
raise OutputParserException(err_msg)
|
94
|
+
|
95
|
+
# Map line IDs to indices (so they can be sorted and lines indexed)
|
96
|
+
index_to_line_id = {0: "START", None: "END"}
|
97
|
+
split_points = {0}
|
98
|
+
for partition in out.__root__:
|
99
|
+
index = self.line_id_to_index[partition.location]
|
100
|
+
index_to_line_id[index] = partition.location
|
101
|
+
split_points.add(index)
|
102
|
+
|
103
|
+
# Get partition start/ends, chunks, chunk lengths
|
104
|
+
split_points = sorted(split_points) + [None]
|
105
|
+
partition_indices = list(zip(split_points, split_points[1:]))
|
106
|
+
partition_points = [
|
107
|
+
(index_to_line_id[i0], index_to_line_id[i1]) for i0, i1 in partition_indices
|
108
|
+
]
|
109
|
+
chunks = ["\n".join(self.lines[i0:i1]) for i0, i1 in partition_indices]
|
110
|
+
chunk_tokens = list(map(self.model.get_num_tokens, chunks))
|
111
|
+
|
112
|
+
# Collect any chunks that exceed token limit
|
113
|
+
oversized_indices: list[int] = [
|
114
|
+
i for i, n in enumerate(chunk_tokens) if n > self.token_limit
|
115
|
+
]
|
116
|
+
if oversized_indices:
|
117
|
+
data = list(zip(partition_points, chunks, chunk_tokens))
|
118
|
+
data = [data[i] for i in oversized_indices]
|
119
|
+
|
120
|
+
problem_points = "\n".join(
|
121
|
+
[
|
122
|
+
f"{i0} to {i1} ({t / self.token_limit:.1f}x maximum length)"
|
123
|
+
for (i0, i1), _, t in data
|
124
|
+
]
|
125
|
+
)
|
126
|
+
log.warning(f"Found {len(data)} oversized chunks:\n{problem_points}")
|
127
|
+
log.debug(
|
128
|
+
"Oversized chunks:\n"
|
129
|
+
+ "\n#############\n".join(chunk for _, chunk, _ in data)
|
130
|
+
)
|
131
|
+
raise OutputParserException(
|
132
|
+
f"The following segments are too long and must be "
|
133
|
+
f"further subdivided:\n{problem_points}"
|
134
|
+
)
|
135
|
+
|
136
|
+
return "\n<JANUS_PARTITION>\n".join(chunks)
|
janus/refiners/refiner.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
import re
|
1
2
|
from typing import Any
|
2
3
|
|
3
4
|
from langchain.output_parsers import RetryWithErrorOutputParser
|
@@ -27,7 +28,7 @@ class JanusRefiner(JanusParser):
|
|
27
28
|
|
28
29
|
class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
|
29
30
|
def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
|
30
|
-
retry_prompt = MODEL_PROMPT_ENGINES[llm.
|
31
|
+
retry_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
|
31
32
|
source_language="text",
|
32
33
|
prompt_template="refinement/fix_exceptions",
|
33
34
|
).prompt
|
@@ -46,6 +47,7 @@ class ReflectionRefiner(JanusRefiner):
|
|
46
47
|
max_retries: int
|
47
48
|
reflection_chain: RunnableSerializable
|
48
49
|
revision_chain: RunnableSerializable
|
50
|
+
reflection_prompt_name: str
|
49
51
|
|
50
52
|
def __init__(
|
51
53
|
self,
|
@@ -54,11 +56,11 @@ class ReflectionRefiner(JanusRefiner):
|
|
54
56
|
max_retries: int,
|
55
57
|
prompt_template_name: str = "refinement/reflection",
|
56
58
|
):
|
57
|
-
reflection_prompt = MODEL_PROMPT_ENGINES[llm.
|
59
|
+
reflection_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
|
58
60
|
source_language="text",
|
59
61
|
prompt_template=prompt_template_name,
|
60
62
|
).prompt
|
61
|
-
revision_prompt = MODEL_PROMPT_ENGINES[llm.
|
63
|
+
revision_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
|
62
64
|
source_language="text",
|
63
65
|
prompt_template="refinement/revision",
|
64
66
|
).prompt
|
@@ -66,6 +68,7 @@ class ReflectionRefiner(JanusRefiner):
|
|
66
68
|
reflection_chain = reflection_prompt | llm | StrOutputParser()
|
67
69
|
revision_chain = revision_prompt | llm | StrOutputParser()
|
68
70
|
super().__init__(
|
71
|
+
reflection_prompt_name=prompt_template_name,
|
69
72
|
reflection_chain=reflection_chain,
|
70
73
|
revision_chain=revision_chain,
|
71
74
|
parser=parser,
|
@@ -75,6 +78,7 @@ class ReflectionRefiner(JanusRefiner):
|
|
75
78
|
def parse_completion(
|
76
79
|
self, completion: str, prompt_value: PromptValue, **kwargs
|
77
80
|
) -> Any:
|
81
|
+
log.info(f"Reflection Prompt: {self.reflection_prompt_name}")
|
78
82
|
for retry_number in range(self.max_retries):
|
79
83
|
reflection = self.reflection_chain.invoke(
|
80
84
|
dict(
|
@@ -82,7 +86,7 @@ class ReflectionRefiner(JanusRefiner):
|
|
82
86
|
completion=completion,
|
83
87
|
)
|
84
88
|
)
|
85
|
-
if
|
89
|
+
if re.search(r"\bLGTM\b", reflection) is not None:
|
86
90
|
return self.parser.parse(completion)
|
87
91
|
if not retry_number:
|
88
92
|
log.info(f"Completion:\n{completion}")
|
@@ -105,11 +109,3 @@ class HallucinationRefiner(ReflectionRefiner):
|
|
105
109
|
prompt_template_name="refinement/hallucination",
|
106
110
|
**kwargs,
|
107
111
|
)
|
108
|
-
|
109
|
-
|
110
|
-
REFINERS = dict(
|
111
|
-
none=JanusRefiner,
|
112
|
-
parser=FixParserExceptions,
|
113
|
-
reflection=ReflectionRefiner,
|
114
|
-
hallucination=HallucinationRefiner,
|
115
|
-
)
|
janus/refiners/uml.py
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
from janus.llm.models_info import JanusModel
|
2
|
+
from janus.parsers.parser import JanusParser
|
3
|
+
from janus.refiners.refiner import ReflectionRefiner
|
4
|
+
|
5
|
+
|
6
|
+
class ALCFixUMLVariablesRefiner(ReflectionRefiner):
|
7
|
+
def __init__(
|
8
|
+
self,
|
9
|
+
llm: JanusModel,
|
10
|
+
parser: JanusParser,
|
11
|
+
max_retries: int,
|
12
|
+
):
|
13
|
+
super().__init__(
|
14
|
+
llm=llm,
|
15
|
+
parser=parser,
|
16
|
+
max_retries=max_retries,
|
17
|
+
prompt_template_name="refinement/uml/alc_fix_variables",
|
18
|
+
)
|
19
|
+
|
20
|
+
|
21
|
+
class FixUMLConnectionsRefiner(ReflectionRefiner):
|
22
|
+
def __init__(
|
23
|
+
self,
|
24
|
+
llm: JanusModel,
|
25
|
+
parser: JanusParser,
|
26
|
+
max_retries: int,
|
27
|
+
):
|
28
|
+
super().__init__(
|
29
|
+
llm=llm,
|
30
|
+
parser=parser,
|
31
|
+
max_retries=max_retries,
|
32
|
+
prompt_template_name="refinement/uml/fix_connections",
|
33
|
+
)
|
janus/retrievers/retriever.py
CHANGED
@@ -1,7 +1,16 @@
|
|
1
|
+
from typing import List
|
2
|
+
|
3
|
+
from langchain_core.documents import Document
|
4
|
+
from langchain_core.output_parsers import StrOutputParser
|
1
5
|
from langchain_core.retrievers import BaseRetriever
|
2
6
|
from langchain_core.runnables import Runnable, RunnableConfig
|
3
7
|
|
4
8
|
from janus.language.block import CodeBlock
|
9
|
+
from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
|
10
|
+
from janus.utils.logger import create_logger
|
11
|
+
from janus.utils.pdf_docs_reader import PDFDocsReader
|
12
|
+
|
13
|
+
log = create_logger(__name__)
|
5
14
|
|
6
15
|
|
7
16
|
class JanusRetriever(Runnable):
|
@@ -40,3 +49,54 @@ class TextSearchRetriever(JanusRetriever):
|
|
40
49
|
docs = self.retriever.invoke(code_block.text)
|
41
50
|
context = "\n\n".join(doc.page_content for doc in docs)
|
42
51
|
return f"You may use the following additional context: {context}"
|
52
|
+
|
53
|
+
|
54
|
+
class LanguageDocsRetriever(JanusRetriever):
    """Retrieve excerpts of a language reference manual relevant to a code block.

    The LLM is first asked (via ``prompt_template_name``) which language
    functionality the code block uses; the returned names are then searched
    for in the reference PDF loaded by :class:`PDFDocsReader`.
    """

    def __init__(
        self,
        llm: JanusModel,
        language_name: str,
        prompt_template_name: str = "retrieval/language_docs",
    ):
        """
        Arguments:
            llm: Model asked which documentation to look up.
            language_name: Source language name; also passed to PDFDocsReader
                to select the reference manual.
            prompt_template_name: Prompt template used for the lookup request.
        """
        super().__init__()
        self.llm: JanusModel = llm
        self.language: str = language_name

        self.PDF_reader = PDFDocsReader(
            language=self.language,
        )

        language_docs_prompt = MODEL_PROMPT_ENGINES[self.llm.short_model_id](
            source_language=self.language,
            prompt_template=prompt_template_name,
        ).prompt

        parser: StrOutputParser = StrOutputParser()
        self.chain = language_docs_prompt | self.llm | parser

    def get_context(self, code_block: CodeBlock) -> str:
        """Return a prompt snippet with documentation excerpts for ``code_block``.

        Returns an empty string when the model answers ``NODOCS``, names no
        functionality, or when no excerpt is similar enough to any request.
        """
        response: str = self.chain.invoke(
            {"SOURCE_CODE": code_block.text, "SOURCE_LANGUAGE": self.language}
        )
        # LLM output often carries leading/trailing whitespace or a newline;
        # normalize before comparing against the sentinel.
        response = response.strip()
        if response == "NODOCS":
            log.debug("No Opcodes requested from language docs retriever.")
            return ""

        # Expected to be a comma-separated list; tolerate inconsistent spacing
        # and drop empty entries rather than searching for "".
        functionality_to_reference: List[str] = [
            item.strip() for item in response.split(",") if item.strip()
        ]
        if not functionality_to_reference:
            return ""

        log.debug(
            f"List of opcodes requested by language docs retriever "
            f"to search the {self.language} "
            f"docs for: {functionality_to_reference}"
        )

        docs: List[Document] = self.PDF_reader.search_language_reference(
            functionality_to_reference
        )
        context = "\n\n".join(doc.page_content for doc in docs)
        if not context:
            return ""
        return (
            f"You may reference the following excerpts from the {self.language} "
            f"language documentation: {context}"
        )
|
@@ -0,0 +1,134 @@
|
|
1
|
+
import os
|
2
|
+
import time
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import List, Optional
|
5
|
+
|
6
|
+
import joblib
|
7
|
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
8
|
+
from langchain_core.documents import Document
|
9
|
+
from langchain_unstructured import UnstructuredLoader
|
10
|
+
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
|
11
|
+
from sklearn.metrics.pairwise import cosine_similarity
|
12
|
+
|
13
|
+
from janus.utils.logger import create_logger
|
14
|
+
|
15
|
+
log = create_logger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class PDFDocsReader:
    """Load, chunk and search a language reference manual stored as a PDF.

    The manual is read from ``<RETRIEVAL_DOCS_DIR>/<language>.pdf`` (the
    directory defaults to ``retrieval_docs``). The chunked text is cached with
    joblib as ``<language>_documents.pkl`` in the same directory and reused by
    later instances.

    Note: the cache file name depends only on ``language``; while a cache file
    exists, ``chunk_size``, ``chunk_overlap``, ``start_page`` and ``end_page``
    have no effect.
    """

    def __init__(
        self,
        language: str,
        chunk_size: int = 1000,
        chunk_overlap: int = 100,
        start_page: Optional[int] = None,
        end_page: Optional[int] = None,
        vectorizer: Optional[CountVectorizer] = None,
    ):
        """
        Arguments:
            language: Language name; selects the PDF and the cache file.
            chunk_size: Maximum chunk size passed to the text splitter.
            chunk_overlap: Overlap between consecutive chunks.
            start_page: Forwarded to UnstructuredLoader.
            end_page: Forwarded to UnstructuredLoader.
            vectorizer: scikit-learn vectorizer that is fit on the chunks and
                used to embed queries. Defaults to a new ``TfidfVectorizer``
                for each instance.
        """
        self.retrieval_docs_dir: Path = Path(
            os.getenv("RETRIEVAL_DOCS_DIR", "retrieval_docs")
        )
        self.language = language
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.start_page = start_page
        self.end_page = end_page
        # Create the vectorizer per instance: a default instance shared by all
        # readers would be refit by each one, silently replacing the vocabulary
        # that earlier readers' doc_vectors were computed with.
        self.vectorizer = vectorizer if vectorizer is not None else TfidfVectorizer()
        self.documents = self.load_and_chunk_pdf()
        self.doc_vectors = self.vectorize_documents()

    def load_and_chunk_pdf(self) -> List[str]:
        """Return the reference manual as a list of text chunks.

        Uses the cached ``<language>_documents.pkl`` when present; otherwise
        parses ``<language>.pdf`` with unstructured, splits the text with a
        RecursiveCharacterTextSplitter, and writes the cache.

        Raises:
            FileNotFoundError: if neither the cache nor the PDF exists.
        """
        pdf_path = self.retrieval_docs_dir / f"{self.language}.pdf"
        pickled_documents_path = (
            self.retrieval_docs_dir / f"{self.language}_documents.pkl"
        )

        if pickled_documents_path.exists():
            log.debug(
                f"Loading pre-chunked PDF from {pickled_documents_path}. "
                f"If you want to regenerate retrieval docs for {self.language}, "
                f"delete the file at {pickled_documents_path}, "
                f"then add a new {self.language}.pdf."
            )
            # NOTE(security): joblib.load unpickles the file and can execute
            # arbitrary code; RETRIEVAL_DOCS_DIR must point at a trusted directory.
            documents = joblib.load(pickled_documents_path)
        else:
            if not pdf_path.exists():
                raise FileNotFoundError(
                    f"Language docs retrieval is enabled, but no PDF for language "
                    f"'{self.language}' was found. Move a "
                    f"{self.language} reference manual to "
                    f"{pdf_path.absolute()} "
                    f"(the path to the directory of PDF docs can be "
                    f"set with the env variable 'RETRIEVAL_DOCS_DIR')."
                )
            log.info(
                f"Chunking reference PDF for {self.language} using unstructured - "
                f"if your PDF has many pages, this could take a while..."
            )
            start_time = time.time()
            loader = UnstructuredLoader(
                pdf_path,
                chunking_strategy="basic",
                max_characters=1000000,
                include_orig_elements=False,
                start_page=self.start_page,
                end_page=self.end_page,
            )
            docs = loader.load()
            text = "\n\n".join([doc.page_content for doc in docs])
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
            )
            documents = text_splitter.split_text(text)
            log.info(f"Document store created for language: {self.language}")
            end_time = time.time()
            log.info(
                f"Processing time for {self.language} PDF: "
                f"{end_time - start_time} seconds"
            )

            joblib.dump(documents, pickled_documents_path)
            log.debug(f"Documents saved to {pickled_documents_path}")

        return documents

    def vectorize_documents(self):
        """Fit ``self.vectorizer`` on the chunks and return the document-term
        matrix produced by its ``fit_transform``."""
        doc_vectors = self.vectorizer.fit_transform(self.documents)
        return doc_vectors

    def search_language_reference(
        self,
        query: List[str],
        top_k: int = 1,
        min_similarity: float = 0.1,
    ) -> List[Document]:
        """Search the vectorized manual for each query term.

        Each entry of ``query`` is transformed with the fitted vectorizer and
        compared with every chunk by cosine similarity. For each entry, up to
        ``top_k`` chunks scoring at least ``min_similarity`` are kept, best
        first. The results for all entries are concatenated in query order and
        returned as langchain Documents.
        """
        docs: List[Document] = []

        for item in query:
            # Transform the query using the fitted vectorizer
            query_vector = self.vectorizer.transform([item])

            # Calculate cosine similarities between the query and document vectors
            similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()

            # Get the indices of documents with similarity above the threshold
            valid_indices = [
                i for i, sim in enumerate(similarities) if sim >= min_similarity
            ]

            # Sort by similarity (descending) and keep the top-k
            top_indices = sorted(
                valid_indices, key=lambda i: similarities[i], reverse=True
            )[:top_k]

            docs += [Document(page_content=self.documents[i]) for i in top_indices]
        log.debug(f"Langauge documentation search result: {docs}")

        return docs
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: janus-llm
|
3
|
-
Version: 4.1.0
|
3
|
+
Version: 4.2.0
|
4
4
|
Summary: A transcoding library using LLMs.
|
5
5
|
Home-page: https://github.com/janus-llm/janus-llm
|
6
6
|
License: Apache 2.0
|
@@ -23,20 +23,28 @@ Requires-Dist: langchain-anthropic (>=0.1.15,<0.2.0)
|
|
23
23
|
Requires-Dist: langchain-community (>=0.2.0,<0.3.0)
|
24
24
|
Requires-Dist: langchain-core (>=0.2.0,<0.3.0)
|
25
25
|
Requires-Dist: langchain-openai (>=0.1.8,<0.2.0)
|
26
|
+
Requires-Dist: langchain-unstructured (>=0.1.2,<0.2.0)
|
26
27
|
Requires-Dist: nltk (>=3.8.1,<4.0.0)
|
27
28
|
Requires-Dist: numpy (>=1.24.3,<2.0.0)
|
28
29
|
Requires-Dist: openai (>=1.14.0,<2.0.0)
|
30
|
+
Requires-Dist: pi-heif (>=0.20.0,<0.21.0)
|
29
31
|
Requires-Dist: py-readability-metrics (>=1.4.5,<2.0.0)
|
30
32
|
Requires-Dist: py-rouge (>=1.1,<2.0)
|
33
|
+
Requires-Dist: pytesseract (>=0.3.13,<0.4.0)
|
31
34
|
Requires-Dist: python-dotenv (>=1.0.0,<2.0.0)
|
32
35
|
Requires-Dist: rich (>=13.7.1,<14.0.0)
|
33
36
|
Requires-Dist: sacrebleu (>=2.4.1,<3.0.0)
|
37
|
+
Requires-Dist: scikit-learn (>=1.5.2,<2.0.0)
|
34
38
|
Requires-Dist: sentence-transformers (>=2.6.1,<3.0.0) ; extra == "hf-local" or extra == "all"
|
39
|
+
Requires-Dist: tesseract (>=0.1.3,<0.2.0)
|
35
40
|
Requires-Dist: text-generation (>=0.6.0,<0.7.0)
|
36
41
|
Requires-Dist: tiktoken (>=0.7.0,<0.8.0)
|
37
42
|
Requires-Dist: transformers (>=4.31.0,<5.0.0)
|
38
43
|
Requires-Dist: tree-sitter (>=0.21.0,<0.22.0)
|
39
44
|
Requires-Dist: typer (>=0.9.0,<0.10.0)
|
45
|
+
Requires-Dist: unstructured (>=0.15.9,<0.16.0)
|
46
|
+
Requires-Dist: unstructured-inference (>=0.7.36,<0.8.0)
|
47
|
+
Requires-Dist: unstructured-pytesseract (>=0.3.13,<0.4.0)
|
40
48
|
Project-URL: Documentation, https://janus-llm.github.io/janus-llm
|
41
49
|
Project-URL: Repository, https://github.com/janus-llm/janus-llm
|
42
50
|
Description-Content-Type: text/markdown
|
@@ -1,17 +1,18 @@
|
|
1
|
-
janus/__init__.py,sha256=
|
1
|
+
janus/__init__.py,sha256=8ZZh7ctoYQaClu_ak9pFc5eYVEcaSju33Ru0vZBp_iM,361
|
2
2
|
janus/__main__.py,sha256=lEkpNtLVPtFo8ySDZeXJ_NXDHb0GVdZFPWB4gD4RPS8,64
|
3
3
|
janus/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
janus/_tests/conftest.py,sha256=V7uW-oq3YbFiRPvrq15YoVVrA1n_83pjgiyTZ-IUGW8,963
|
5
5
|
janus/_tests/test_cli.py,sha256=6ef7h11bg4i7Q6L1-r0ZdcY7YrH4n472kvDiA03T4c8,4275
|
6
|
-
janus/cli.py,sha256=
|
7
|
-
janus/converter/__init__.py,sha256=
|
6
|
+
janus/cli.py,sha256=eGmzu8aei1QNN_WaWeMYltgIHdKr1MPwG2Er0AEBIuo,42563
|
7
|
+
janus/converter/__init__.py,sha256=Jnp3TsJ4M1LWDAzXFSyxzMpygbYOxkR-qYxU-G6Gi1k,395
|
8
8
|
janus/converter/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
9
|
janus/converter/_tests/test_translate.py,sha256=T5CzNrwHqJWfb39Izq84R9WvM3toSlJq31SeA_U7d_4,5641
|
10
10
|
janus/converter/aggregator.py,sha256=MuAXMKmq6PuUo_w6ljyiuDn81Gk2dN-Ci7FVeLc6vhs,1966
|
11
|
-
janus/converter/converter.py,sha256=
|
11
|
+
janus/converter/converter.py,sha256=1WFGy8LozO8pVjbPcYJa9-TTZqgNxwUs7oDca86TcvE,26174
|
12
12
|
janus/converter/diagram.py,sha256=-wktVBPrSBgNIQfHIfa2bJNg6L9CYJQgrr9-xU8DFPw,1646
|
13
13
|
janus/converter/document.py,sha256=qNt2UncMheUBadXCFHGq74tqCrvZub5DCgZpd3Qa54o,4564
|
14
14
|
janus/converter/evaluate.py,sha256=APWQUY3gjAXqkJkPzvj0UA4wPK3Cv9QSJLM-YK9t-ng,476
|
15
|
+
janus/converter/partition.py,sha256=ASvv4hAue44qHobO4kqr_tKr-eJsXCPPdD3NtNd9V-E,993
|
15
16
|
janus/converter/requirements.py,sha256=9tvQ40FZJtG8niIFn45gPQCgKKHVPPoFLinBv6RAqO4,2027
|
16
17
|
janus/converter/translate.py,sha256=S1DPZdmX9Vrn_sJPcobvXmhmS8U53yl5cRXjsmXPtas,4246
|
17
18
|
janus/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -37,7 +38,7 @@ janus/language/binary/_tests/test_binary.py,sha256=cIKIxjj6kIY3rcxLwqUPESP9bxWrH
|
|
37
38
|
janus/language/binary/binary.py,sha256=PHsVa8jcM7sW9aTboGRWXj6ewQznz0kVPNWtP4B9YPU,6555
|
38
39
|
janus/language/binary/reveng/decompile_script.py,sha256=veW51oJzuO-4UD3Er062jXZ_FYtTFo9OCkl82Z2xr6A,2182
|
39
40
|
janus/language/block.py,sha256=2rjAYUosHFfWRgLnzf50uAgTMST4Md9Kx6JrlUfEfX4,9398
|
40
|
-
janus/language/combine.py,sha256=
|
41
|
+
janus/language/combine.py,sha256=egZRl1xZXAFXa2ZjjfqnNckc9uxuo6e1MJgkRrCgvd8,3650
|
41
42
|
janus/language/file.py,sha256=jy-cReAoI6F97TXR5bbhPyt8XyUZCdFYnVboubDA_y4,571
|
42
43
|
janus/language/mumps/__init__.py,sha256=-Ou_wJ-JgHezfp1dub2_qCYNiK9wO-zo2MlqxM9qiwE,48
|
43
44
|
janus/language/mumps/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -58,7 +59,7 @@ janus/language/treesitter/_tests/test_treesitter.py,sha256=fmr_mFSja7vaCVu0TVyLD
|
|
58
59
|
janus/language/treesitter/treesitter.py,sha256=q7fqfFxt7QsqM6tE39uqutRMsOfEgBd3omv7zVZSEOc,7517
|
59
60
|
janus/llm/__init__.py,sha256=TKLYvnsWKWfxMucy-lCLQ-4bkN9ENotJZDywDEQmrKg,45
|
60
61
|
janus/llm/model_callbacks.py,sha256=cHRZBpYgAwiYbA2k0GQ7DBwBFQZJpEGMUBV3Q_5GTpU,7940
|
61
|
-
janus/llm/models_info.py,sha256=
|
62
|
+
janus/llm/models_info.py,sha256=6ImXTgCeNkMPtW-9swdaWXISixb-UUqq6OCUl8kPxCs,10612
|
62
63
|
janus/metrics/__init__.py,sha256=AsxtZJUzZiXJPr2ehPPltuYP-ddechjg6X85WZUO7mA,241
|
63
64
|
janus/metrics/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
64
65
|
janus/metrics/_tests/reference.py,sha256=hiaJPP9CXkvFBV_wL-gOe_BzELTw0nvB6uCxhxtIiE8,13
|
@@ -89,21 +90,24 @@ janus/parsers/code_parser.py,sha256=3l0HfzgrvJuiwk779s9ZsgUl3xbp1nE1qZxh8aDYRBI,
|
|
89
90
|
janus/parsers/doc_parser.py,sha256=0pUsNZ9hKQLjIi8L8BgkOBHQZ_EGoFLHrBQ4hoDkjSw,5862
|
90
91
|
janus/parsers/eval_parser.py,sha256=Gjh6aTZgpYd2ASJUEPMo4LpCL00cBmbOqc4KM3hy8x8,2922
|
91
92
|
janus/parsers/parser.py,sha256=y6VV64bgVidf-oEFla3I--_28tnJsPBc6QUD_SkbfSE,1614
|
93
|
+
janus/parsers/partition_parser.py,sha256=z9EoqttHacegZzhkoGa-j4vxuzaleDuq32FonzaXsW8,4974
|
92
94
|
janus/parsers/reqs_parser.py,sha256=uRQC41Iqp22GjIvakb5UKv70UWHkcOTbOVl_RDnipYw,2438
|
93
95
|
janus/parsers/uml.py,sha256=SwaoG9QrHKQP8rSxlf3qu_rp7OMQqYSmLgDYBapOa9M,3379
|
94
96
|
janus/prompts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
95
97
|
janus/prompts/prompt.py,sha256=3796YXIzzIec9b0iUzd8VZlq-AdQbzq8qUGXLy4KH-0,10586
|
96
|
-
janus/refiners/refiner.py,sha256=
|
97
|
-
janus/
|
98
|
+
janus/refiners/refiner.py,sha256=f2YDLnG2TF3Kws40chVOBQ91DD6zf2B1wcoP6WeQcIk,3829
|
99
|
+
janus/refiners/uml.py,sha256=ZFvFLxOdbolYuOmZh_8K6kiHCWKuudqP71sr_TammxM,866
|
100
|
+
janus/retrievers/retriever.py,sha256=n6MzoNZs0GJCH4eqQPS3gFlVHZ3eETr7FuHYbyPzTuo,3506
|
98
101
|
janus/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
99
102
|
janus/utils/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
100
103
|
janus/utils/_tests/test_logger.py,sha256=jkkvrCTKwsFCsZtmyuvc-WJ0rC7LJi2Z91sIe4IiKzA,2209
|
101
104
|
janus/utils/_tests/test_progress.py,sha256=Rs_u5PiGjP-L-o6C1fhwfE1ig8jYu9Xo9s4p8yPysl8,491
|
102
105
|
janus/utils/enums.py,sha256=AoilbdiYyMvY2Mp0AM4xlbLSELfut2XMwhIM1S_msP4,27610
|
103
106
|
janus/utils/logger.py,sha256=KZeuaMAnlSZCsj4yL0P6N-JzZwpxXygzACWfdZFeuek,2337
|
107
|
+
janus/utils/pdf_docs_reader.py,sha256=beMKHdYrFwg0m_i7n0OTJrut3sf4rEWFd7P_80A76WY,5140
|
104
108
|
janus/utils/progress.py,sha256=PIpcQec7SrhsfqB25LHj2CDDkfm9umZx90d9LZnAx6k,1469
|
105
|
-
janus_llm-4.
|
106
|
-
janus_llm-4.
|
107
|
-
janus_llm-4.
|
108
|
-
janus_llm-4.
|
109
|
-
janus_llm-4.
|
109
|
+
janus_llm-4.2.0.dist-info/LICENSE,sha256=_j0st0a-HB6MRbP3_BW3PUqpS16v54luyy-1zVyl8NU,10789
|
110
|
+
janus_llm-4.2.0.dist-info/METADATA,sha256=5iwBiBTpucpwF3UxClv2P25y9QOpaWsaEGFFyF7mmTU,4574
|
111
|
+
janus_llm-4.2.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
112
|
+
janus_llm-4.2.0.dist-info/entry_points.txt,sha256=OGhQwzj6pvXp79B0SaBD5apGekCu7Dwe9fZZT_TZ544,39
|
113
|
+
janus_llm-4.2.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|