janus-llm 3.5.2__py3-none-any.whl → 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +1 -1
- janus/cli.py +90 -42
- janus/converter/converter.py +111 -142
- janus/converter/diagram.py +21 -109
- janus/converter/translate.py +1 -1
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +16 -11
- janus/language/binary/_tests/test_binary.py +1 -1
- janus/language/binary/binary.py +2 -2
- janus/language/mumps/_tests/test_mumps.py +1 -1
- janus/language/mumps/mumps.py +2 -3
- janus/language/naive/simple_ast.py +3 -2
- janus/language/splitter.py +7 -4
- janus/language/treesitter/_tests/test_treesitter.py +1 -1
- janus/language/treesitter/treesitter.py +2 -2
- janus/llm/model_callbacks.py +13 -0
- janus/llm/models_info.py +118 -71
- janus/metrics/metric.py +15 -14
- janus/parsers/uml.py +60 -23
- janus/refiners/refiner.py +106 -64
- janus/retrievers/retriever.py +42 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/METADATA +1 -1
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/RECORD +26 -26
- janus/parsers/refiner_parser.py +0 -46
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/LICENSE +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/WHEEL +0 -0
- {janus_llm-3.5.2.dist-info → janus_llm-4.0.0.dist-info}/entry_points.txt +0 -0
    
        janus/__init__.py
    CHANGED
    
    | @@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning | |
| 5 5 | 
             
            from janus.converter.translate import Translator
         | 
| 6 6 | 
             
            from janus.metrics import *  # noqa: F403
         | 
| 7 7 |  | 
| 8 | 
            -
            __version__ = " | 
| 8 | 
            +
            __version__ = "4.0.0"
         | 
| 9 9 |  | 
| 10 10 | 
             
            # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
         | 
| 11 11 | 
             
            warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
         | 
    
        janus/cli.py
    CHANGED
    
    | @@ -1,6 +1,7 @@ | |
| 1 1 | 
             
            import json
         | 
| 2 2 | 
             
            import logging
         | 
| 3 3 | 
             
            import os
         | 
| 4 | 
            +
            import subprocess  # nosec
         | 
| 4 5 | 
             
            from pathlib import Path
         | 
| 5 6 | 
             
            from typing import List, Optional
         | 
| 6 7 |  | 
| @@ -42,6 +43,7 @@ from janus.llm.models_info import ( | |
| 42 43 | 
             
                openai_models,
         | 
| 43 44 | 
             
            )
         | 
| 44 45 | 
             
            from janus.metrics.cli import evaluate
         | 
| 46 | 
            +
            from janus.refiners.refiner import REFINERS
         | 
| 45 47 | 
             
            from janus.utils.enums import LANGUAGES
         | 
| 46 48 | 
             
            from janus.utils.logger import create_logger
         | 
| 47 49 |  | 
| @@ -241,6 +243,24 @@ def translate( | |
| 241 243 | 
             
                        click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
         | 
| 242 244 | 
             
                    ),
         | 
| 243 245 | 
             
                ] = "file",
         | 
| 246 | 
            +
                refiner_type: Annotated[
         | 
| 247 | 
            +
                    str,
         | 
| 248 | 
            +
                    typer.Option(
         | 
| 249 | 
            +
                        "-r",
         | 
| 250 | 
            +
                        "--refiner",
         | 
| 251 | 
            +
                        help="Name of custom refiner to use",
         | 
| 252 | 
            +
                        click_type=click.Choice(list(REFINERS.keys())),
         | 
| 253 | 
            +
                    ),
         | 
| 254 | 
            +
                ] = "none",
         | 
| 255 | 
            +
                retriever_type: Annotated[
         | 
| 256 | 
            +
                    str,
         | 
| 257 | 
            +
                    typer.Option(
         | 
| 258 | 
            +
                        "-R",
         | 
| 259 | 
            +
                        "--retriever",
         | 
| 260 | 
            +
                        help="Name of custom retriever to use",
         | 
| 261 | 
            +
                        click_type=click.Choice(["active_usings"]),
         | 
| 262 | 
            +
                    ),
         | 
| 263 | 
            +
                ] = None,
         | 
| 244 264 | 
             
                max_tokens: Annotated[
         | 
| 245 265 | 
             
                    int,
         | 
| 246 266 | 
             
                    typer.Option(
         | 
| @@ -250,13 +270,6 @@ def translate( | |
| 250 270 | 
             
                        "If unspecificed, model's default max will be used.",
         | 
| 251 271 | 
             
                    ),
         | 
| 252 272 | 
             
                ] = None,
         | 
| 253 | 
            -
                skip_refiner: Annotated[
         | 
| 254 | 
            -
                    bool,
         | 
| 255 | 
            -
                    typer.Option(
         | 
| 256 | 
            -
                        "--skip-refiner",
         | 
| 257 | 
            -
                        help="Whether to skip the refiner for generating output",
         | 
| 258 | 
            -
                    ),
         | 
| 259 | 
            -
                ] = True,
         | 
| 260 273 | 
             
            ):
         | 
| 261 274 | 
             
                try:
         | 
| 262 275 | 
             
                    target_language, target_version = target_lang.split("-")
         | 
| @@ -282,8 +295,8 @@ def translate( | |
| 282 295 | 
             
                    db_path=db_loc,
         | 
| 283 296 | 
             
                    db_config=collections_config,
         | 
| 284 297 | 
             
                    splitter_type=splitter_type,
         | 
| 285 | 
            -
                     | 
| 286 | 
            -
                     | 
| 298 | 
            +
                    refiner_type=refiner_type,
         | 
| 299 | 
            +
                    retriever_type=retriever_type,
         | 
| 287 300 | 
             
                )
         | 
| 288 301 | 
             
                translator.translate(input_dir, output_dir, overwrite, collection)
         | 
| 289 302 |  | 
| @@ -341,14 +354,6 @@ def document( | |
| 341 354 | 
             
                        help="Whether to overwrite existing files in the output directory",
         | 
| 342 355 | 
             
                    ),
         | 
| 343 356 | 
             
                ] = False,
         | 
| 344 | 
            -
                skip_context: Annotated[
         | 
| 345 | 
            -
                    bool,
         | 
| 346 | 
            -
                    typer.Option(
         | 
| 347 | 
            -
                        "--skip-context",
         | 
| 348 | 
            -
                        help="Prompts will include any context information associated with source"
         | 
| 349 | 
            -
                        " code blocks, unless this option is specified",
         | 
| 350 | 
            -
                    ),
         | 
| 351 | 
            -
                ] = False,
         | 
| 352 357 | 
             
                doc_mode: Annotated[
         | 
| 353 358 | 
             
                    str,
         | 
| 354 359 | 
             
                    typer.Option(
         | 
| @@ -396,6 +401,24 @@ def document( | |
| 396 401 | 
             
                        click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
         | 
| 397 402 | 
             
                    ),
         | 
| 398 403 | 
             
                ] = "file",
         | 
| 404 | 
            +
                refiner_type: Annotated[
         | 
| 405 | 
            +
                    str,
         | 
| 406 | 
            +
                    typer.Option(
         | 
| 407 | 
            +
                        "-r",
         | 
| 408 | 
            +
                        "--refiner",
         | 
| 409 | 
            +
                        help="Name of custom refiner to use",
         | 
| 410 | 
            +
                        click_type=click.Choice(list(REFINERS.keys())),
         | 
| 411 | 
            +
                    ),
         | 
| 412 | 
            +
                ] = "none",
         | 
| 413 | 
            +
                retriever_type: Annotated[
         | 
| 414 | 
            +
                    str,
         | 
| 415 | 
            +
                    typer.Option(
         | 
| 416 | 
            +
                        "-R",
         | 
| 417 | 
            +
                        "--retriever",
         | 
| 418 | 
            +
                        help="Name of custom retriever to use",
         | 
| 419 | 
            +
                        click_type=click.Choice(["active_usings"]),
         | 
| 420 | 
            +
                    ),
         | 
| 421 | 
            +
                ] = None,
         | 
| 399 422 | 
             
                max_tokens: Annotated[
         | 
| 400 423 | 
             
                    int,
         | 
| 401 424 | 
             
                    typer.Option(
         | 
| @@ -405,13 +428,6 @@ def document( | |
| 405 428 | 
             
                        "If unspecificed, model's default max will be used.",
         | 
| 406 429 | 
             
                    ),
         | 
| 407 430 | 
             
                ] = None,
         | 
| 408 | 
            -
                skip_refiner: Annotated[
         | 
| 409 | 
            -
                    bool,
         | 
| 410 | 
            -
                    typer.Option(
         | 
| 411 | 
            -
                        "--skip-refiner",
         | 
| 412 | 
            -
                        help="Whether to skip the refiner for generating output",
         | 
| 413 | 
            -
                    ),
         | 
| 414 | 
            -
                ] = True,
         | 
| 415 431 | 
             
            ):
         | 
| 416 432 | 
             
                model_arguments = dict(temperature=temperature)
         | 
| 417 433 | 
             
                collections_config = get_collections_config()
         | 
| @@ -424,8 +440,8 @@ def document( | |
| 424 440 | 
             
                    db_path=db_loc,
         | 
| 425 441 | 
             
                    db_config=collections_config,
         | 
| 426 442 | 
             
                    splitter_type=splitter_type,
         | 
| 427 | 
            -
                     | 
| 428 | 
            -
                     | 
| 443 | 
            +
                    refiner_type=refiner_type,
         | 
| 444 | 
            +
                    retriever_type=retriever_type,
         | 
| 429 445 | 
             
                )
         | 
| 430 446 | 
             
                if doc_mode == "madlibs":
         | 
| 431 447 | 
             
                    documenter = MadLibsDocumenter(
         | 
| @@ -614,14 +630,6 @@ def diagram( | |
| 614 630 | 
             
                        help="Whether to overwrite existing files in the output directory",
         | 
| 615 631 | 
             
                    ),
         | 
| 616 632 | 
             
                ] = False,
         | 
| 617 | 
            -
                skip_context: Annotated[
         | 
| 618 | 
            -
                    bool,
         | 
| 619 | 
            -
                    typer.Option(
         | 
| 620 | 
            -
                        "--skip-context",
         | 
| 621 | 
            -
                        help="Prompts will include any context information associated with source"
         | 
| 622 | 
            -
                        " code blocks, unless this option is specified",
         | 
| 623 | 
            -
                    ),
         | 
| 624 | 
            -
                ] = False,
         | 
| 625 633 | 
             
                temperature: Annotated[
         | 
| 626 634 | 
             
                    float,
         | 
| 627 635 | 
             
                    typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
         | 
| @@ -658,13 +666,24 @@ def diagram( | |
| 658 666 | 
             
                        click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
         | 
| 659 667 | 
             
                    ),
         | 
| 660 668 | 
             
                ] = "file",
         | 
| 661 | 
            -
                 | 
| 662 | 
            -
                     | 
| 669 | 
            +
                refiner_type: Annotated[
         | 
| 670 | 
            +
                    str,
         | 
| 663 671 | 
             
                    typer.Option(
         | 
| 664 | 
            -
                        " | 
| 665 | 
            -
                         | 
| 672 | 
            +
                        "-r",
         | 
| 673 | 
            +
                        "--refiner",
         | 
| 674 | 
            +
                        help="Name of custom refiner to use",
         | 
| 675 | 
            +
                        click_type=click.Choice(list(REFINERS.keys())),
         | 
| 666 676 | 
             
                    ),
         | 
| 667 | 
            -
                ] =  | 
| 677 | 
            +
                ] = "none",
         | 
| 678 | 
            +
                retriever_type: Annotated[
         | 
| 679 | 
            +
                    str,
         | 
| 680 | 
            +
                    typer.Option(
         | 
| 681 | 
            +
                        "-R",
         | 
| 682 | 
            +
                        "--retriever",
         | 
| 683 | 
            +
                        help="Name of custom retriever to use",
         | 
| 684 | 
            +
                        click_type=click.Choice(["active_usings"]),
         | 
| 685 | 
            +
                    ),
         | 
| 686 | 
            +
                ] = None,
         | 
| 668 687 | 
             
            ):
         | 
| 669 688 | 
             
                model_arguments = dict(temperature=temperature)
         | 
| 670 689 | 
             
                collections_config = get_collections_config()
         | 
| @@ -675,11 +694,11 @@ def diagram( | |
| 675 694 | 
             
                    max_prompts=max_prompts,
         | 
| 676 695 | 
             
                    db_path=db_loc,
         | 
| 677 696 | 
             
                    db_config=collections_config,
         | 
| 697 | 
            +
                    splitter_type=splitter_type,
         | 
| 698 | 
            +
                    refiner_type=refiner_type,
         | 
| 699 | 
            +
                    retriever_type=retriever_type,
         | 
| 678 700 | 
             
                    diagram_type=diagram_type,
         | 
| 679 701 | 
             
                    add_documentation=add_documentation,
         | 
| 680 | 
            -
                    splitter_type=splitter_type,
         | 
| 681 | 
            -
                    skip_refiner=skip_refiner,
         | 
| 682 | 
            -
                    skip_context=skip_context,
         | 
| 683 702 | 
             
                )
         | 
| 684 703 | 
             
                diagram_generator.translate(input_dir, output_dir, overwrite, collection)
         | 
| 685 704 |  | 
| @@ -1156,5 +1175,34 @@ app.add_typer(evaluate, name="evaluate") | |
| 1156 1175 | 
             
            app.add_typer(embedding, name="embedding")
         | 
| 1157 1176 |  | 
| 1158 1177 |  | 
| 1178 | 
            +
            @app.command()
         | 
| 1179 | 
            +
            def render(
         | 
| 1180 | 
            +
                input_dir: Annotated[
         | 
| 1181 | 
            +
                    str,
         | 
| 1182 | 
            +
                    typer.Option(
         | 
| 1183 | 
            +
                        "--input",
         | 
| 1184 | 
            +
                        "-i",
         | 
| 1185 | 
            +
                    ),
         | 
| 1186 | 
            +
                ],
         | 
| 1187 | 
            +
                output_dir: Annotated[str, typer.Option("--output", "-o")],
         | 
| 1188 | 
            +
            ):
         | 
| 1189 | 
            +
                input_dir = Path(input_dir)
         | 
| 1190 | 
            +
                output_dir = Path(output_dir)
         | 
| 1191 | 
            +
                for input_file in input_dir.rglob("*.json"):
         | 
| 1192 | 
            +
                    with open(input_file, "r") as f:
         | 
| 1193 | 
            +
                        data = json.load(f)
         | 
| 1194 | 
            +
             | 
| 1195 | 
            +
                    output_file = output_dir / input_file.relative_to(input_dir).with_suffix(".txt")
         | 
| 1196 | 
            +
                    if not output_file.parent.exists():
         | 
| 1197 | 
            +
                        output_file.parent.mkdir()
         | 
| 1198 | 
            +
             | 
| 1199 | 
            +
                    text = data["output"].replace("\\n", "\n").strip()
         | 
| 1200 | 
            +
                    output_file.write_text(text)
         | 
| 1201 | 
            +
             | 
| 1202 | 
            +
                    jar_path = homedir / ".janus/lib/plantuml.jar"
         | 
| 1203 | 
            +
                    subprocess.run(["java", "-jar", jar_path, output_file])  # nosec
         | 
| 1204 | 
            +
                    output_file.unlink()
         | 
| 1205 | 
            +
             | 
| 1206 | 
            +
             | 
| 1159 1207 | 
             
            if __name__ == "__main__":
         | 
| 1160 1208 | 
             
                app()
         | 
    
        janus/converter/converter.py
    CHANGED
    
    | @@ -2,13 +2,11 @@ import functools | |
| 2 2 | 
             
            import json
         | 
| 3 3 | 
             
            import time
         | 
| 4 4 | 
             
            from pathlib import Path
         | 
| 5 | 
            -
            from typing import Any | 
| 5 | 
            +
            from typing import Any
         | 
| 6 6 |  | 
| 7 | 
            -
            from langchain.output_parsers import RetryWithErrorOutputParser
         | 
| 8 7 | 
             
            from langchain_core.exceptions import OutputParserException
         | 
| 9 | 
            -
            from langchain_core. | 
| 10 | 
            -
            from langchain_core. | 
| 11 | 
            -
            from langchain_core.runnables import RunnableLambda, RunnableParallel
         | 
| 8 | 
            +
            from langchain_core.prompts import ChatPromptTemplate
         | 
| 9 | 
            +
            from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
         | 
| 12 10 | 
             
            from openai import BadRequestError, RateLimitError
         | 
| 13 11 | 
             
            from pydantic import ValidationError
         | 
| 14 12 |  | 
| @@ -22,12 +20,18 @@ from janus.language.splitter import ( | |
| 22 20 | 
             
                Splitter,
         | 
| 23 21 | 
             
                TokenLimitError,
         | 
| 24 22 | 
             
            )
         | 
| 25 | 
            -
            from janus.llm import load_model
         | 
| 26 23 | 
             
            from janus.llm.model_callbacks import get_model_callback
         | 
| 27 | 
            -
            from janus.llm.models_info import MODEL_PROMPT_ENGINES
         | 
| 24 | 
            +
            from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel, load_model
         | 
| 28 25 | 
             
            from janus.parsers.parser import GenericParser, JanusParser
         | 
| 29 | 
            -
            from janus. | 
| 30 | 
            -
             | 
| 26 | 
            +
            from janus.refiners.refiner import (
         | 
| 27 | 
            +
                FixParserExceptions,
         | 
| 28 | 
            +
                HallucinationRefiner,
         | 
| 29 | 
            +
                JanusRefiner,
         | 
| 30 | 
            +
                ReflectionRefiner,
         | 
| 31 | 
            +
            )
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            # from janus.refiners.refiner import BasicRefiner, Refiner
         | 
| 34 | 
            +
            from janus.retrievers.retriever import ActiveUsingsRetriever, JanusRetriever
         | 
| 31 35 | 
             
            from janus.utils.enums import LANGUAGES
         | 
| 32 36 | 
             
            from janus.utils.logger import create_logger
         | 
| 33 37 |  | 
| @@ -74,9 +78,8 @@ class Converter: | |
| 74 78 | 
             
                    protected_node_types: tuple[str, ...] = (),
         | 
| 75 79 | 
             
                    prune_node_types: tuple[str, ...] = (),
         | 
| 76 80 | 
             
                    splitter_type: str = "file",
         | 
| 77 | 
            -
                    refiner_type: str =  | 
| 78 | 
            -
                     | 
| 79 | 
            -
                    skip_context: bool = False,
         | 
| 81 | 
            +
                    refiner_type: str | None = None,
         | 
| 82 | 
            +
                    retriever_type: str | None = None,
         | 
| 80 83 | 
             
                ) -> None:
         | 
| 81 84 | 
             
                    """Initialize a Converter instance.
         | 
| 82 85 |  | 
| @@ -96,9 +99,13 @@ class Converter: | |
| 96 99 | 
             
                        prune_node_types: A set of node types which should be pruned.
         | 
| 97 100 | 
             
                        splitter_type: The type of splitter to use. Valid values are `"file"`,
         | 
| 98 101 | 
             
                            `"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
         | 
| 99 | 
            -
                        refiner_type: The type of refiner to use. Valid values | 
| 100 | 
            -
             | 
| 101 | 
            -
             | 
| 102 | 
            +
                        refiner_type: The type of refiner to use. Valid values:
         | 
| 103 | 
            +
                            - "parser"
         | 
| 104 | 
            +
                            - "reflection"
         | 
| 105 | 
            +
                            - None
         | 
| 106 | 
            +
                        retriever_type: The type of retriever to use. Valid values:
         | 
| 107 | 
            +
                            - "active_usings"
         | 
| 108 | 
            +
                            - None
         | 
| 102 109 | 
             
                    """
         | 
| 103 110 | 
             
                    self._changed_attrs: set = set()
         | 
| 104 111 |  | 
| @@ -107,7 +114,6 @@ class Converter: | |
| 107 114 | 
             
                    self.override_token_limit: bool = max_tokens is not None
         | 
| 108 115 |  | 
| 109 116 | 
             
                    self._model_name: str
         | 
| 110 | 
            -
                    self._model_id: str
         | 
| 111 117 | 
             
                    self._custom_model_arguments: dict[str, Any]
         | 
| 112 118 |  | 
| 113 119 | 
             
                    self._source_language: str
         | 
| @@ -120,24 +126,26 @@ class Converter: | |
| 120 126 | 
             
                    self._prune_node_types: tuple[str, ...] = ()
         | 
| 121 127 | 
             
                    self._max_tokens: int | None = max_tokens
         | 
| 122 128 | 
             
                    self._prompt_template_name: str
         | 
| 123 | 
            -
                    self._splitter_type: str
         | 
| 124 129 | 
             
                    self._db_path: str | None
         | 
| 125 130 | 
             
                    self._db_config: dict[str, Any] | None
         | 
| 126 131 |  | 
| 127 | 
            -
                    self. | 
| 128 | 
            -
                    self._llm: BaseLanguageModel
         | 
| 132 | 
            +
                    self._llm: JanusModel
         | 
| 129 133 | 
             
                    self._prompt: ChatPromptTemplate
         | 
| 130 134 |  | 
| 131 135 | 
             
                    self._parser: JanusParser = GenericParser()
         | 
| 132 136 | 
             
                    self._combiner: Combiner = Combiner()
         | 
| 133 137 |  | 
| 134 | 
            -
                    self. | 
| 135 | 
            -
                    self. | 
| 138 | 
            +
                    self._splitter_type: str
         | 
| 139 | 
            +
                    self._refiner_type: str | None
         | 
| 140 | 
            +
                    self._retriever_type: str | None
         | 
| 136 141 |  | 
| 137 | 
            -
                    self. | 
| 142 | 
            +
                    self._splitter: Splitter
         | 
| 143 | 
            +
                    self._refiner: JanusRefiner
         | 
| 144 | 
            +
                    self._retriever: JanusRetriever
         | 
| 138 145 |  | 
| 139 146 | 
             
                    self.set_splitter(splitter_type=splitter_type)
         | 
| 140 147 | 
             
                    self.set_refiner(refiner_type=refiner_type)
         | 
| 148 | 
            +
                    self.set_retriever(retriever_type=retriever_type)
         | 
| 141 149 | 
             
                    self.set_model(model_name=model, **model_arguments)
         | 
| 142 150 | 
             
                    self.set_prompt(prompt_template=prompt_template)
         | 
| 143 151 | 
             
                    self.set_source_language(source_language)
         | 
| @@ -146,8 +154,6 @@ class Converter: | |
| 146 154 | 
             
                    self.set_db_path(db_path=db_path)
         | 
| 147 155 | 
             
                    self.set_db_config(db_config=db_config)
         | 
| 148 156 |  | 
| 149 | 
            -
                    self.skip_context = skip_context
         | 
| 150 | 
            -
             | 
| 151 157 | 
             
                    # Child class must call this. Should we enforce somehow?
         | 
| 152 158 | 
             
                    # self._load_parameters()
         | 
| 153 159 |  | 
| @@ -163,9 +169,11 @@ class Converter: | |
| 163 169 | 
             
                def _load_parameters(self) -> None:
         | 
| 164 170 | 
             
                    self._load_model()
         | 
| 165 171 | 
             
                    self._load_prompt()
         | 
| 172 | 
            +
                    self._load_retriever()
         | 
| 173 | 
            +
                    self._load_refiner()
         | 
| 166 174 | 
             
                    self._load_splitter()
         | 
| 167 175 | 
             
                    self._load_vectorizer()
         | 
| 168 | 
            -
                    self. | 
| 176 | 
            +
                    self._load_chain()
         | 
| 169 177 | 
             
                    self._changed_attrs.clear()
         | 
| 170 178 |  | 
| 171 179 | 
             
                def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
         | 
| @@ -184,8 +192,6 @@ class Converter: | |
| 184 192 | 
             
                def set_prompt(self, prompt_template: str) -> None:
         | 
| 185 193 | 
             
                    """Validate and set the prompt template name.
         | 
| 186 194 |  | 
| 187 | 
            -
                    The affected objects will not be updated until translate() is called.
         | 
| 188 | 
            -
             | 
| 189 195 | 
             
                    Arguments:
         | 
| 190 196 | 
             
                        prompt_template: name of prompt template directory
         | 
| 191 197 | 
             
                            (see janus/prompts/templates) or path to a directory.
         | 
| @@ -195,29 +201,34 @@ class Converter: | |
| 195 201 | 
             
                def set_splitter(self, splitter_type: str) -> None:
         | 
| 196 202 | 
             
                    """Validate and set the prompt template name.
         | 
| 197 203 |  | 
| 198 | 
            -
                    The affected objects will not be updated until translate() is called.
         | 
| 199 | 
            -
             | 
| 200 204 | 
             
                    Arguments:
         | 
| 201 205 | 
             
                        prompt_template: name of prompt template directory
         | 
| 202 206 | 
             
                            (see janus/prompts/templates) or path to a directory.
         | 
| 203 207 | 
             
                    """
         | 
| 204 | 
            -
                     | 
| 208 | 
            +
                    if splitter_type not in CUSTOM_SPLITTERS:
         | 
| 209 | 
            +
                        raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
         | 
| 205 210 |  | 
| 206 | 
            -
             | 
| 207 | 
            -
                    """Validate and set the refiner name
         | 
| 211 | 
            +
                    self._splitter_type = splitter_type
         | 
| 208 212 |  | 
| 209 | 
            -
             | 
| 213 | 
            +
                def set_refiner(self, refiner_type: str | None) -> None:
         | 
| 214 | 
            +
                    """Validate and set the refiner type
         | 
| 210 215 |  | 
| 211 216 | 
             
                    Arguments:
         | 
| 212 | 
            -
                        refiner_type: the  | 
| 217 | 
            +
                        refiner_type: the type of refiner to use
         | 
| 213 218 | 
             
                    """
         | 
| 214 219 | 
             
                    self._refiner_type = refiner_type
         | 
| 215 220 |  | 
| 221 | 
            +
                def set_retriever(self, retriever_type: str | None) -> None:
         | 
| 222 | 
            +
                    """Validate and set the retriever type
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    Arguments:
         | 
| 225 | 
            +
                        retriever_type: the type of retriever to use
         | 
| 226 | 
            +
                    """
         | 
| 227 | 
            +
                    self._retriever_type = retriever_type
         | 
| 228 | 
            +
             | 
| 216 229 | 
             
                def set_source_language(self, source_language: str) -> None:
         | 
| 217 230 | 
             
                    """Validate and set the source language.
         | 
| 218 231 |  | 
| 219 | 
            -
                    The affected objects will not be updated until _load_parameters() is called.
         | 
| 220 | 
            -
             | 
| 221 232 | 
             
                    Arguments:
         | 
| 222 233 | 
             
                        source_language: The source programming language.
         | 
| 223 234 | 
             
                    """
         | 
| @@ -287,20 +298,6 @@ class Converter: | |
| 287 298 |  | 
| 288 299 | 
             
                    self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
         | 
| 289 300 |  | 
| 290 | 
            -
                @run_if_changed("_refiner_type", "_model_name")
         | 
| 291 | 
            -
                def _load_refiner(self) -> None:
         | 
| 292 | 
            -
                    """Load the refiner according to this instance's attributes.
         | 
| 293 | 
            -
             | 
| 294 | 
            -
                    If the relevant fields have not been changed since the last time this method was
         | 
| 295 | 
            -
                    called, nothing happens.
         | 
| 296 | 
            -
                    """
         | 
| 297 | 
            -
                    if self._refiner_type == "basic":
         | 
| 298 | 
            -
                        self._refiner = BasicRefiner(
         | 
| 299 | 
            -
                            "basic_refinement", self._model_id, self._source_language
         | 
| 300 | 
            -
                        )
         | 
| 301 | 
            -
                    else:
         | 
| 302 | 
            -
                        raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
         | 
| 303 | 
            -
             | 
| 304 301 | 
             
                @run_if_changed("_model_name", "_custom_model_arguments")
         | 
| 305 302 | 
             
                def _load_model(self) -> None:
         | 
| 306 303 | 
             
                    """Load the model according to this instance's attributes.
         | 
| @@ -314,9 +311,9 @@ class Converter: | |
| 314 311 | 
             
                    # model_arguments.update(self._custom_model_arguments)
         | 
| 315 312 |  | 
| 316 313 | 
             
                    # Load the model
         | 
| 317 | 
            -
                    self._llm | 
| 318 | 
            -
             | 
| 319 | 
            -
             | 
| 314 | 
            +
                    self._llm = load_model(self._model_name)
         | 
| 315 | 
            +
                    token_limit = self._llm.token_limit
         | 
| 316 | 
            +
             | 
| 320 317 | 
             
                    # Set the max_tokens to less than half the model's limit to allow for enough
         | 
| 321 318 | 
             
                    # tokens at output
         | 
| 322 319 | 
             
                    # Only modify max_tokens if it is not specified by user
         | 
| @@ -335,7 +332,7 @@ class Converter: | |
| 335 332 | 
             
                    If the relevant fields have not been changed since the last time this
         | 
| 336 333 | 
             
                    method was called, nothing happens.
         | 
| 337 334 | 
             
                    """
         | 
| 338 | 
            -
                    prompt_engine = MODEL_PROMPT_ENGINES[self. | 
| 335 | 
            +
                    prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
         | 
| 339 336 | 
             
                        source_language=self._source_language,
         | 
| 340 337 | 
             
                        prompt_template=self._prompt_template_name,
         | 
| 341 338 | 
             
                    )
         | 
| @@ -354,6 +351,59 @@ class Converter: | |
| 354 351 | 
             
                        self._db_path, self._db_config
         | 
| 355 352 | 
             
                    )
         | 
| 356 353 |  | 
| 354 | 
            +
                @run_if_changed("_retriever_type")
         | 
| 355 | 
            +
                def _load_retriever(self):
         | 
| 356 | 
            +
                    if self._retriever_type == "active_usings":
         | 
| 357 | 
            +
                        self._retriever = ActiveUsingsRetriever()
         | 
| 358 | 
            +
                    else:
         | 
| 359 | 
            +
                        self._retriever = JanusRetriever()
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                @run_if_changed("_refiner_type", "_model_name", "max_prompts", "_parser", "_llm")
         | 
| 362 | 
            +
                def _load_refiner(self) -> None:
         | 
| 363 | 
            +
                    """Load the refiner according to this instance's attributes.
         | 
| 364 | 
            +
             | 
| 365 | 
            +
                    If the relevant fields have not been changed since the last time this method was
         | 
| 366 | 
            +
                    called, nothing happens.
         | 
| 367 | 
            +
                    """
         | 
| 368 | 
            +
                    if self._refiner_type == "parser":
         | 
| 369 | 
            +
                        self._refiner = FixParserExceptions(
         | 
| 370 | 
            +
                            llm=self._llm,
         | 
| 371 | 
            +
                            parser=self._parser,
         | 
| 372 | 
            +
                            max_retries=self.max_prompts,
         | 
| 373 | 
            +
                        )
         | 
| 374 | 
            +
                    elif self._refiner_type == "reflection":
         | 
| 375 | 
            +
                        self._refiner = ReflectionRefiner(
         | 
| 376 | 
            +
                            llm=self._llm,
         | 
| 377 | 
            +
                            parser=self._parser,
         | 
| 378 | 
            +
                            max_retries=self.max_prompts,
         | 
| 379 | 
            +
                        )
         | 
| 380 | 
            +
                    elif self._refiner_type == "hallucination":
         | 
| 381 | 
            +
                        self._refiner = HallucinationRefiner(
         | 
| 382 | 
            +
                            llm=self._llm,
         | 
| 383 | 
            +
                            parser=self._parser,
         | 
| 384 | 
            +
                            max_retries=self.max_prompts,
         | 
| 385 | 
            +
                        )
         | 
| 386 | 
            +
                    else:
         | 
| 387 | 
            +
                        self._refiner = JanusRefiner(parser=self._parser)
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                @run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner")
         | 
| 390 | 
            +
                def _load_chain(self):
         | 
| 391 | 
            +
                    self.chain = (
         | 
| 392 | 
            +
                        self._input_runnable()
         | 
| 393 | 
            +
                        | self._prompt
         | 
| 394 | 
            +
                        | RunnableParallel(
         | 
| 395 | 
            +
                            completion=self._llm,
         | 
| 396 | 
            +
                            prompt_value=RunnablePassthrough(),
         | 
| 397 | 
            +
                        )
         | 
| 398 | 
            +
                        | self._refiner.parse_runnable
         | 
| 399 | 
            +
                    )
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                def _input_runnable(self) -> Runnable:
         | 
| 402 | 
            +
                    return RunnableParallel(
         | 
| 403 | 
            +
                        SOURCE_CODE=self._parser.parse_input,
         | 
| 404 | 
            +
                        context=self._retriever,
         | 
| 405 | 
            +
                    )
         | 
| 406 | 
            +
             | 
| 357 407 | 
             
                def translate(
         | 
| 358 408 | 
             
                    self,
         | 
| 359 409 | 
             
                    input_directory: str | Path,
         | 
| @@ -598,110 +648,29 @@ class Converter: | |
| 598 648 | 
             
                    return root
         | 
| 599 649 |  | 
| 600 650 | 
             
                def _run_chain(self, block: TranslatedCodeBlock) -> str:
         | 
| 601 | 
            -
                     | 
| 602 | 
            -
                    First, try to fix simple formatting errors by giving the model just
         | 
| 603 | 
            -
                    the output and the parsing error. After a number of attempts, try
         | 
| 604 | 
            -
                    giving the model the output, the parsing error, and the original
         | 
| 605 | 
            -
                    input. Again check/retry this output to solve for formatting errors.
         | 
| 606 | 
            -
                    If we still haven't succeeded after several attempts, the model may
         | 
| 607 | 
            -
                    be getting thrown off by a bad initial output; start from scratch
         | 
| 608 | 
            -
                    and try again.
         | 
| 609 | 
            -
             | 
| 610 | 
            -
                    The number of tries for each layer of this scheme is roughly equal
         | 
| 611 | 
            -
                    to the cube root of self.max_retries, so the total calls to the
         | 
| 612 | 
            -
                    LLM will be roughly as expected (up to sqrt(self.max_retries) over)
         | 
| 613 | 
            -
                    """
         | 
| 614 | 
            -
                    input = self._parser.parse_input(block.original)
         | 
| 615 | 
            -
             | 
| 616 | 
            -
                    # Retries with just the output and the error
         | 
| 617 | 
            -
                    n1 = round(self.max_prompts ** (1 / 2))
         | 
| 618 | 
            -
             | 
| 619 | 
            -
                    # Retries with the input, output, and error
         | 
| 620 | 
            -
                    n2 = round(self.max_prompts // n1)
         | 
| 621 | 
            -
             | 
| 622 | 
            -
                    if not self.skip_context:
         | 
| 623 | 
            -
                        self._make_prompt_additions(block)
         | 
| 624 | 
            -
                    if not self.skip_refiner:  # Make replacements in the prompt
         | 
| 625 | 
            -
                        refine_output = RefinerParser(
         | 
| 626 | 
            -
                            parser=self._parser,
         | 
| 627 | 
            -
                            initial_prompt=self._prompt.format(**{"SOURCE_CODE": input}),
         | 
| 628 | 
            -
                            refiner=self._refiner,
         | 
| 629 | 
            -
                            max_retries=n1,
         | 
| 630 | 
            -
                            llm=self._llm,
         | 
| 631 | 
            -
                        )
         | 
| 632 | 
            -
                    else:
         | 
| 633 | 
            -
                        refine_output = RetryWithErrorOutputParser.from_llm(
         | 
| 634 | 
            -
                            llm=self._llm,
         | 
| 635 | 
            -
                            parser=self._parser,
         | 
| 636 | 
            -
                            max_retries=n1,
         | 
| 637 | 
            -
                        )
         | 
| 638 | 
            -
             | 
| 639 | 
            -
                    completion_chain = self._prompt | self._llm
         | 
| 640 | 
            -
                    chain = RunnableParallel(
         | 
| 641 | 
            -
                        completion=completion_chain, prompt_value=self._prompt
         | 
| 642 | 
            -
                    ) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
         | 
| 643 | 
            -
                    for _ in range(n2):
         | 
| 644 | 
            -
                        try:
         | 
| 645 | 
            -
                            return chain.invoke({"SOURCE_CODE": input})
         | 
| 646 | 
            -
                        except OutputParserException:
         | 
| 647 | 
            -
                            pass
         | 
| 648 | 
            -
             | 
| 649 | 
            -
                    raise OutputParserException(f"Failed to parse after {n1*n2} retries")
         | 
| 651 | 
            +
                    return self.chain.invoke(block.original)
         | 
| 650 652 |  | 
| 651 653 | 
             
                def _get_output_obj(
         | 
| 652 654 | 
             
                    self, block: TranslatedCodeBlock
         | 
| 653 | 
            -
                ) -> dict[str, int | float | str | dict[str, str]]:
         | 
| 655 | 
            +
                ) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
         | 
| 654 656 | 
             
                    output_str = self._parser.parse_combined_output(block.complete_text)
         | 
| 655 657 |  | 
| 656 | 
            -
                     | 
| 658 | 
            +
                    output_obj: str | dict[str, str]
         | 
| 657 659 | 
             
                    try:
         | 
| 658 | 
            -
                         | 
| 660 | 
            +
                        output_obj = json.loads(output_str)
         | 
| 659 661 | 
             
                    except json.JSONDecodeError:
         | 
| 660 | 
            -
                         | 
| 662 | 
            +
                        output_obj = output_str
         | 
| 661 663 |  | 
| 662 664 | 
             
                    return dict(
         | 
| 663 | 
            -
                        input=block.original.text,
         | 
| 665 | 
            +
                        input=block.original.text or "",
         | 
| 664 666 | 
             
                        metadata=dict(
         | 
| 665 667 | 
             
                            retries=block.total_retries,
         | 
| 666 668 | 
             
                            cost=block.total_cost,
         | 
| 667 669 | 
             
                            processing_time=block.processing_time,
         | 
| 668 670 | 
             
                        ),
         | 
| 669 | 
            -
                        output= | 
| 670 | 
            -
                    )
         | 
| 671 | 
            -
             | 
| 672 | 
            -
                @staticmethod
         | 
| 673 | 
            -
                def _get_prompt_additions(block) -> Optional[List[Tuple[str, str]]]:
         | 
| 674 | 
            -
                    """Get a list of strings to append to the prompt.
         | 
| 675 | 
            -
             | 
| 676 | 
            -
                    Arguments:
         | 
| 677 | 
            -
                        block: The `TranslatedCodeBlock` to save to a file.
         | 
| 678 | 
            -
                    """
         | 
| 679 | 
            -
                    return [(key, item) for key, item in block.context_tags.items()]
         | 
| 680 | 
            -
             | 
| 681 | 
            -
                def _make_prompt_additions(self, block: CodeBlock):
         | 
| 682 | 
            -
                    # Prepare the additional context to prepend
         | 
| 683 | 
            -
                    additional_context = "".join(
         | 
| 684 | 
            -
                        [
         | 
| 685 | 
            -
                            f"{context_tag}: {context}\n"
         | 
| 686 | 
            -
                            for context_tag, context in self._get_prompt_additions(block)
         | 
| 687 | 
            -
                        ]
         | 
| 671 | 
            +
                        output=output_obj,
         | 
| 688 672 | 
             
                    )
         | 
| 689 673 |  | 
| 690 | 
            -
                    if not hasattr(self._prompt, "messages"):
         | 
| 691 | 
            -
                        log.debug("Skipping additions to prompt, no messages found on prompt object!")
         | 
| 692 | 
            -
                        return
         | 
| 693 | 
            -
             | 
| 694 | 
            -
                    # Iterate through existing messages to find and update the system message
         | 
| 695 | 
            -
                    for i, message in enumerate(self._prompt.messages):
         | 
| 696 | 
            -
                        if isinstance(message, SystemMessagePromptTemplate):
         | 
| 697 | 
            -
                            # Prepend the additional context to the system message
         | 
| 698 | 
            -
                            updated_system_message = SystemMessagePromptTemplate.from_template(
         | 
| 699 | 
            -
                                additional_context + message.prompt.template
         | 
| 700 | 
            -
                            )
         | 
| 701 | 
            -
                            # Directly modify the message in the list
         | 
| 702 | 
            -
                            self._prompt.messages[i] = updated_system_message
         | 
| 703 | 
            -
                            break  # Assuming there's only one system message to update
         | 
| 704 | 
            -
             | 
| 705 674 | 
             
                def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
         | 
| 706 675 | 
             
                    """Save a file to disk.
         | 
| 707 676 |  |