janus-llm 4.1.0__py3-none-any.whl → 4.3.1__py3-none-any.whl

janus/__init__.py CHANGED
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
 from janus.converter.translate import Translator
 from janus.metrics import *  # noqa: F403
 
-__version__ = "4.1.0"
+__version__ = "4.3.1"
 
 # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
 warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
janus/cli.py CHANGED
@@ -13,10 +13,14 @@ from rich.console import Console
 from rich.prompt import Confirm
 from typing_extensions import Annotated
 
+import janus.refiners.refiner
+import janus.refiners.uml
 from janus.converter.aggregator import Aggregator
 from janus.converter.converter import Converter
 from janus.converter.diagram import DiagramGenerator
 from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
+from janus.converter.evaluate import InlineCommentEvaluator, RequirementEvaluator
+from janus.converter.partition import Partitioner
 from janus.converter.requirements import RequirementsDocumenter
 from janus.converter.translate import Translator
 from janus.embedding.collections import Collections
@@ -44,7 +48,6 @@ from janus.llm.models_info import (
     openai_models,
 )
 from janus.metrics.cli import evaluate
-from janus.refiners.refiner import REFINERS
 from janus.utils.enums import LANGUAGES
 from janus.utils.logger import create_logger
 
@@ -69,6 +72,18 @@ with open(db_file, "r") as f:
 collections_config_file = Path(db_loc) / "collections.json"
 
 
+def get_subclasses(cls):
+    return set(cls.__subclasses__()).union(
+        set(s for c in cls.__subclasses__() for s in get_subclasses(c))
+    )
+
+
+REFINER_TYPES = get_subclasses(janus.refiners.refiner.JanusRefiner).union(
+    {janus.refiners.refiner.JanusRefiner}
+)
+REFINERS = {r.__name__: r for r in REFINER_TYPES}
+
+
 def get_collections_config():
     if collections_config_file.exists():
         with open(collections_config_file, "r") as f:
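
Note: the get_subclasses helper and REFINERS registry added here replace the REFINERS constant previously imported from janus.refiners.refiner (removed in the hunk above). The registry walks every class reachable from JanusRefiner via __subclasses__() and keys it by class name, which is what feeds the CLI's click.Choice lists below. A minimal standalone sketch of the pattern, using hypothetical classes rather than the package's own refiners:

    class Refiner: ...
    class ReflectionRefiner(Refiner): ...
    class StrictReflectionRefiner(ReflectionRefiner): ...

    def get_subclasses(cls):
        # Direct subclasses plus, recursively, the subclasses of those subclasses.
        return set(cls.__subclasses__()).union(
            s for c in cls.__subclasses__() for s in get_subclasses(c)
        )

    REGISTRY = {r.__name__: r for r in get_subclasses(Refiner) | {Refiner}}
    print(sorted(REGISTRY))  # ['Refiner', 'ReflectionRefiner', 'StrictReflectionRefiner']

Only subclasses that have been imported by the time the walk runs are discovered, which is presumably why janus.refiners.refiner and janus.refiners.uml are now imported at the top of janus/cli.py.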
@@ -113,7 +128,7 @@ embedding = typer.Typer(
 
 def version_callback(value: bool) -> None:
     if value:
-        from janus import __version__ as version
+        from . import __version__ as version
 
         print(f"Janus CLI [blue]v{version}[/blue]")
         raise typer.Exit()
@@ -244,22 +259,23 @@ def translate(
             click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
         ),
     ] = "file",
-    refiner_type: Annotated[
-        str,
+    refiner_types: Annotated[
+        list[str],
         typer.Option(
             "-r",
             "--refiner",
-            help="Name of custom refiner to use",
+            help="List of refiner types to use. Add -r for each refiner to use in\
+                refinement chain",
             click_type=click.Choice(list(REFINERS.keys())),
         ),
-    ] = "none",
+    ] = ["JanusRefiner"],
     retriever_type: Annotated[
         str,
         typer.Option(
             "-R",
             "--retriever",
             help="Name of custom retriever to use",
-            click_type=click.Choice(["active_usings"]),
+            click_type=click.Choice(["active_usings", "language_docs"]),
         ),
     ] = None,
     max_tokens: Annotated[
@@ -272,6 +288,7 @@ def translate(
         ),
     ] = None,
 ):
+    refiner_types = [REFINERS[r] for r in refiner_types]
     try:
         target_language, target_version = target_lang.split("-")
     except ValueError:
@@ -296,7 +313,7 @@ def translate(
         db_path=db_loc,
         db_config=collections_config,
         splitter_type=splitter_type,
-        refiner_type=refiner_type,
+        refiner_types=refiner_types,
         retriever_type=retriever_type,
     )
     translator.translate(input_dir, output_dir, overwrite, collection)
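
Note: translate, document, partition, diagram, and llm_self_eval all switch from a single --refiner choice to a repeatable -r/--refiner option: annotating the parameter as list[str] makes Typer accept the flag multiple times, and the first line of each command body maps the chosen names through the REFINERS registry to refiner classes. A standalone sketch of that option pattern (hypothetical app and stand-in classes, not the package's code):

    import click
    import typer
    from typing_extensions import Annotated

    class JanusRefiner: ...
    class ReflectionRefiner(JanusRefiner): ...

    REFINERS = {"JanusRefiner": JanusRefiner, "ReflectionRefiner": ReflectionRefiner}
    app = typer.Typer()

    @app.command()
    def translate(
        refiner_types: Annotated[
            list[str],
            typer.Option(
                "-r",
                "--refiner",
                click_type=click.Choice(list(REFINERS.keys())),
            ),
        ] = ["JanusRefiner"],
    ):
        # Each -r occurrence appends one name; resolve the names to classes before use.
        refiner_classes = [REFINERS[r] for r in refiner_types]
        typer.echo(f"refinement chain: {refiner_classes}")

    if __name__ == "__main__":
        app()  # e.g. python sketch.py -r JanusRefiner -r ReflectionRefiner

The order of the -r flags on the command line is preserved in the resulting list, which matters because the Converter chains the refiners in that order (see the _load_refiner_chain hunk further down).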
@@ -402,22 +419,23 @@ def document(
             click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
         ),
     ] = "file",
-    refiner_type: Annotated[
-        str,
+    refiner_types: Annotated[
+        list[str],
         typer.Option(
             "-r",
             "--refiner",
-            help="Name of custom refiner to use",
+            help="List of refiner types to use. Add -r for each refiner to use in\
+                refinement chain",
             click_type=click.Choice(list(REFINERS.keys())),
         ),
-    ] = "none",
+    ] = ["JanusRefiner"],
     retriever_type: Annotated[
         str,
         typer.Option(
             "-R",
             "--retriever",
             help="Name of custom retriever to use",
-            click_type=click.Choice(["active_usings"]),
+            click_type=click.Choice(["active_usings", "language_docs"]),
         ),
     ] = None,
     max_tokens: Annotated[
@@ -430,6 +448,7 @@ def document(
         ),
     ] = None,
 ):
+    refiner_types = [REFINERS[r] for r in refiner_types]
     model_arguments = dict(temperature=temperature)
     collections_config = get_collections_config()
     kwargs = dict(
@@ -441,7 +460,7 @@ def document(
         db_path=db_loc,
         db_config=collections_config,
         splitter_type=splitter_type,
-        refiner_type=refiner_type,
+        refiner_types=refiner_types,
         retriever_type=retriever_type,
     )
     if doc_mode == "madlibs":
@@ -458,12 +477,6 @@ def document(
     documenter.translate(input_dir, output_dir, overwrite, collection)
 
 
-def get_subclasses(cls):
-    return set(cls.__subclasses__()).union(
-        set(s for c in cls.__subclasses__() for s in get_subclasses(c))
-    )
-
-
 @app.command()
 def aggregate(
     input_dir: Annotated[
@@ -578,6 +591,115 @@ def aggregate(
     aggregator.translate(input_dir, output_dir, overwrite, collection)
 
 
+@app.command(
+    help="Partition input code using an LLM.",
+    no_args_is_help=True,
+)
+def partition(
+    input_dir: Annotated[
+        Path,
+        typer.Option(
+            "--input",
+            "-i",
+            help="The directory containing the source code to be partitioned. ",
+        ),
+    ],
+    language: Annotated[
+        str,
+        typer.Option(
+            "--language",
+            "-l",
+            help="The language of the source code.",
+            click_type=click.Choice(sorted(LANGUAGES)),
+        ),
+    ],
+    output_dir: Annotated[
+        Path,
+        typer.Option(
+            "--output-dir", "-o", help="The directory to store the partitioned code in."
+        ),
+    ],
+    llm_name: Annotated[
+        str,
+        typer.Option(
+            "--llm",
+            "-L",
+            help="The custom name of the model set with 'janus llm add'.",
+        ),
+    ] = "gpt-4o",
+    max_prompts: Annotated[
+        int,
+        typer.Option(
+            "--max-prompts",
+            "-m",
+            help="The maximum number of times to prompt a model on one functional block "
+            "before exiting the application. This is to prevent wasting too much money.",
+        ),
+    ] = 10,
+    overwrite: Annotated[
+        bool,
+        typer.Option(
+            "--overwrite/--preserve",
+            help="Whether to overwrite existing files in the output directory",
+        ),
+    ] = False,
+    temperature: Annotated[
+        float,
+        typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
+    ] = 0.7,
+    splitter_type: Annotated[
+        str,
+        typer.Option(
+            "-S",
+            "--splitter",
+            help="Name of custom splitter to use",
+            click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
+        ),
+    ] = "file",
+    refiner_types: Annotated[
+        list[str],
+        typer.Option(
+            "-r",
+            "--refiner",
+            help="List of refiner types to use. Add -r for each refiner to use in\
+                refinement chain",
+            click_type=click.Choice(list(REFINERS.keys())),
+        ),
+    ] = ["JanusRefiner"],
+    max_tokens: Annotated[
+        int,
+        typer.Option(
+            "--max-tokens",
+            "-M",
+            help="The maximum number of tokens the model will take in. "
+            "If unspecificed, model's default max will be used.",
+        ),
+    ] = None,
+    partition_token_limit: Annotated[
+        int,
+        typer.Option(
+            "--partition-tokens",
+            "-pt",
+            help="The limit on the number of tokens per partition.",
+        ),
+    ] = 8192,
+):
+    refiner_types = [REFINERS[r] for r in refiner_types]
+    model_arguments = dict(temperature=temperature)
+    kwargs = dict(
+        model=llm_name,
+        model_arguments=model_arguments,
+        source_language=language,
+        max_prompts=max_prompts,
+        max_tokens=max_tokens,
+        splitter_type=splitter_type,
+        refiner_types=refiner_types,
+        partition_token_limit=partition_token_limit,
+    )
+    partitioner = Partitioner(**kwargs)
+    partitioner.translate(input_dir, output_dir, overwrite)
+
+
 @app.command(
     help="Diagram input code using an LLM.",
     no_args_is_help=True,
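
Note: the new partition command follows the same pattern as the other converter commands, assembling kwargs and delegating to Partitioner.translate. A hedged sketch of driving it directly from Python, using only the argument names visible in this hunk (the sample paths and the "python" source language are assumptions):

    from pathlib import Path

    from janus.converter.partition import Partitioner
    from janus.refiners.refiner import JanusRefiner

    partitioner = Partitioner(
        model="gpt-4o",
        model_arguments=dict(temperature=0.7),
        source_language="python",  # assumed to be a key in janus.utils.enums.LANGUAGES
        max_prompts=10,
        max_tokens=None,
        splitter_type="file",
        refiner_types=[JanusRefiner],
        partition_token_limit=8192,
    )
    # Same call the CLI makes: input directory, output directory, overwrite flag.
    partitioner.translate(Path("./src"), Path("./partitioned"), False)

From the shell this corresponds roughly to `janus partition -i ./src -l python -o ./partitioned -pt 8192`, assuming the console script is installed as janus, as the 'janus llm add' help text suggests.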
@@ -667,25 +789,27 @@ def diagram(
             click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
         ),
     ] = "file",
-    refiner_type: Annotated[
-        str,
+    refiner_types: Annotated[
+        list[str],
         typer.Option(
             "-r",
             "--refiner",
-            help="Name of custom refiner to use",
+            help="List of refiner types to use. Add -r for each refiner to use in\
+                refinement chain",
             click_type=click.Choice(list(REFINERS.keys())),
         ),
-    ] = "none",
+    ] = ["JanusRefiner"],
     retriever_type: Annotated[
         str,
         typer.Option(
             "-R",
             "--retriever",
             help="Name of custom retriever to use",
-            click_type=click.Choice(["active_usings"]),
+            click_type=click.Choice(["active_usings", "language_docs"]),
         ),
     ] = None,
 ):
+    refiner_types = [REFINERS[r] for r in refiner_types]
     model_arguments = dict(temperature=temperature)
     collections_config = get_collections_config()
     diagram_generator = DiagramGenerator(
@@ -696,7 +820,7 @@ def diagram(
         db_path=db_loc,
         db_config=collections_config,
         splitter_type=splitter_type,
-        refiner_type=refiner_type,
+        refiner_types=refiner_types,
         retriever_type=retriever_type,
         diagram_type=diagram_type,
         add_documentation=add_documentation,
@@ -704,6 +828,139 @@ def diagram(
     diagram_generator.translate(input_dir, output_dir, overwrite, collection)
 
 
+@app.command(
+    help="LLM self evaluation",
+    no_args_is_help=True,
+)
+def llm_self_eval(
+    input_dir: Annotated[
+        Path,
+        typer.Option(
+            "--input",
+            "-i",
+            help="The directory containing the source code to be evaluated. "
+            "The files should all be in one flat directory.",
+        ),
+    ],
+    language: Annotated[
+        str,
+        typer.Option(
+            "--language",
+            "-l",
+            help="The language of the source code.",
+            click_type=click.Choice(sorted(LANGUAGES)),
+        ),
+    ],
+    output_dir: Annotated[
+        Path,
+        typer.Option(
+            "--output-dir", "-o", help="The directory to store the evaluations in."
+        ),
+    ],
+    llm_name: Annotated[
+        str,
+        typer.Option(
+            "--llm",
+            "-L",
+            help="The custom name of the model set with 'janus llm add'.",
+        ),
+    ] = "gpt-4o",
+    evaluation_type: Annotated[
+        str,
+        typer.Option(
+            "--evaluation-type",
+            "-e",
+            help="Type of output to evaluate.",
+            click_type=click.Choice(["incose", "comments"]),
+        ),
+    ] = "incose",
+    max_prompts: Annotated[
+        int,
+        typer.Option(
+            "--max-prompts",
+            "-m",
+            help="The maximum number of times to prompt a model on one functional block "
+            "before exiting the application. This is to prevent wasting too much money.",
+        ),
+    ] = 10,
+    overwrite: Annotated[
+        bool,
+        typer.Option(
+            "--overwrite/--preserve",
+            help="Whether to overwrite existing files in the output directory",
+        ),
+    ] = False,
+    temperature: Annotated[
+        float,
+        typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
+    ] = 0.7,
+    collection: Annotated[
+        str,
+        typer.Option(
+            "--collection",
+            "-c",
+            help="If set, will put the translated result into a Chroma DB "
+            "collection with the name provided.",
+        ),
+    ] = None,
+    splitter_type: Annotated[
+        str,
+        typer.Option(
+            "-S",
+            "--splitter",
+            help="Name of custom splitter to use",
+            click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
+        ),
+    ] = "file",
+    refiner_types: Annotated[
+        list[str],
+        typer.Option(
+            "-r",
+            "--refiner",
+            help="List of refiner types to use. Add -r for each refiner to use in\
+                refinement chain",
+            click_type=click.Choice(list(REFINERS.keys())),
+        ),
+    ] = ["JanusRefiner"],
+    eval_items_per_request: Annotated[
+        int,
+        typer.Option(
+            "--eval-items-per-request",
+            "-rc",
+            help="The maximum number of evaluation items per request",
+        ),
+    ] = None,
+    max_tokens: Annotated[
+        int,
+        typer.Option(
+            "--max-tokens",
+            "-M",
+            help="The maximum number of tokens the model will take in. "
+            "If unspecificed, model's default max will be used.",
+        ),
+    ] = None,
+):
+    model_arguments = dict(temperature=temperature)
+    refiner_types = [REFINERS[r] for r in refiner_types]
+    kwargs = dict(
+        eval_items_per_request=eval_items_per_request,
+        model=llm_name,
+        model_arguments=model_arguments,
+        source_language=language,
+        max_prompts=max_prompts,
+        max_tokens=max_tokens,
+        splitter_type=splitter_type,
+        refiner_types=refiner_types,
+    )
+    # Setting parser type here
+    if evaluation_type == "incose":
+        evaluator = RequirementEvaluator(**kwargs)
+    elif evaluation_type == "comments":
+        evaluator = InlineCommentEvaluator(**kwargs)
+
+    evaluator.translate(input_dir, output_dir, overwrite, collection)
+
+
 @db.command("init", help="Connect to or create a database.")
 def db_init(
     path: Annotated[
@@ -1005,13 +1262,12 @@ def llm_add(
            show_choices=False,
        )
        params = dict(
-            # OpenAI uses the "model_name" key for what we're calling "long_model_id"
-            model_name=MODEL_ID_TO_LONG_ID[model_id],
+            model_name=model_name,
            temperature=0.7,
            n=1,
        )
-        max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
-        model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
+        max_tokens = TOKEN_LIMITS[model_name]
+        model_cost = COST_PER_1K_TOKENS[model_name]
        cfg = {
            "model_type": model_type,
            "model_id": model_id,
@@ -2,5 +2,6 @@ from janus.converter.converter import Converter
 from janus.converter.diagram import DiagramGenerator
 from janus.converter.document import Documenter, MadLibsDocumenter, MultiDocumenter
 from janus.converter.evaluate import Evaluator
+from janus.converter.partition import Partitioner
 from janus.converter.requirements import RequirementsDocumenter
 from janus.converter.translate import Translator
@@ -6,7 +6,12 @@ from typing import Any
 
 from langchain_core.exceptions import OutputParserException
 from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
+from langchain_core.runnables import (
+    Runnable,
+    RunnableLambda,
+    RunnableParallel,
+    RunnablePassthrough,
+)
 from openai import BadRequestError, RateLimitError
 from pydantic import ValidationError
 
@@ -23,15 +28,14 @@ from janus.language.splitter import (
 from janus.llm.model_callbacks import get_model_callback
 from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel, load_model
 from janus.parsers.parser import GenericParser, JanusParser
-from janus.refiners.refiner import (
-    FixParserExceptions,
-    HallucinationRefiner,
-    JanusRefiner,
-    ReflectionRefiner,
-)
+from janus.refiners.refiner import JanusRefiner
 
 # from janus.refiners.refiner import BasicRefiner, Refiner
-from janus.retrievers.retriever import ActiveUsingsRetriever, JanusRetriever
+from janus.retrievers.retriever import (
+    ActiveUsingsRetriever,
+    JanusRetriever,
+    LanguageDocsRetriever,
+)
 from janus.utils.enums import LANGUAGES
 from janus.utils.logger import create_logger
 
@@ -78,7 +82,7 @@ class Converter:
         protected_node_types: tuple[str, ...] = (),
         prune_node_types: tuple[str, ...] = (),
         splitter_type: str = "file",
-        refiner_type: str | None = None,
+        refiner_types: list[type[JanusRefiner]] = [JanusRefiner],
         retriever_type: str | None = None,
     ) -> None:
         """Initialize a Converter instance.
@@ -105,6 +109,7 @@ class Converter:
                 - None
             retriever_type: The type of retriever to use. Valid values:
                 - "active_usings"
+                - "language_docs"
                 - None
         """
         self._changed_attrs: set = set()
@@ -133,10 +138,11 @@ class Converter:
         self._prompt: ChatPromptTemplate
 
         self._parser: JanusParser = GenericParser()
+        self._base_parser: JanusParser = GenericParser()
         self._combiner: Combiner = Combiner()
 
         self._splitter_type: str
-        self._refiner_type: str | None
+        self._refiner_types: list[type[JanusRefiner]]
         self._retriever_type: str | None
 
         self._splitter: Splitter
@@ -144,7 +150,7 @@ class Converter:
         self._retriever: JanusRetriever
 
         self.set_splitter(splitter_type=splitter_type)
-        self.set_refiner(refiner_type=refiner_type)
+        self.set_refiner_types(refiner_types=refiner_types)
         self.set_retriever(retriever_type=retriever_type)
         self.set_model(model_name=model, **model_arguments)
         self.set_prompt(prompt_template=prompt_template)
@@ -170,7 +176,7 @@ class Converter:
         self._load_model()
         self._load_prompt()
         self._load_retriever()
-        self._load_refiner()
+        self._load_refiner_chain()
         self._load_splitter()
         self._load_vectorizer()
         self._load_chain()
@@ -210,13 +216,13 @@ class Converter:
 
         self._splitter_type = splitter_type
 
-    def set_refiner(self, refiner_type: str | None) -> None:
+    def set_refiner_types(self, refiner_types: list[type[JanusRefiner]]) -> None:
         """Validate and set the refiner type
 
         Arguments:
             refiner_type: the type of refiner to use
         """
-        self._refiner_type = refiner_type
+        self._refiner_types = refiner_types
 
     def set_retriever(self, retriever_type: str | None) -> None:
         """Validate and set the retriever type
@@ -355,48 +361,40 @@ class Converter:
     def _load_retriever(self):
         if self._retriever_type == "active_usings":
            self._retriever = ActiveUsingsRetriever()
+        elif self._retriever_type == "language_docs":
+            self._retriever = LanguageDocsRetriever(self._llm, self._source_language)
         else:
            self._retriever = JanusRetriever()
 
-    @run_if_changed("_refiner_type", "_model_name", "max_prompts", "_parser", "_llm")
-    def _load_refiner(self) -> None:
-        """Load the refiner according to this instance's attributes.
-
-        If the relevant fields have not been changed since the last time this method was
-        called, nothing happens.
-        """
-        if self._refiner_type == "parser":
-            self._refiner = FixParserExceptions(
-                llm=self._llm,
-                parser=self._parser,
-                max_retries=self.max_prompts,
-            )
-        elif self._refiner_type == "reflection":
-            self._refiner = ReflectionRefiner(
-                llm=self._llm,
-                parser=self._parser,
-                max_retries=self.max_prompts,
+    @run_if_changed("_refiner_types", "_model_name", "max_prompts", "_parser")
+    def _load_refiner_chain(self) -> None:
+        self._refiner_chain = RunnableParallel(
+            completion=self._llm,
+            prompt_value=RunnablePassthrough(),
+        )
+        for refiner_type in self._refiner_types[:-1]:
+            # NOTE: Do NOT remove refiner_type=refiner_type from lambda.
+            # Due to lambda capture, must be present or chain will not
+            # be correctly constructed.
+            self._refiner_chain = self._refiner_chain | RunnableParallel(
+                completion=lambda x, refiner_type=refiner_type: refiner_type(
+                    llm=self._llm,
+                    parser=self._base_parser,
+                    max_retries=self.max_prompts,
+                ).parse_completion(**x),
+                prompt_value=lambda x: x["prompt_value"],
             )
-        elif self._refiner_type == "hallucination":
-            self._refiner = HallucinationRefiner(
+        self._refiner_chain = self._refiner_chain | RunnableLambda(
+            lambda x: self._refiner_types[-1](
                 llm=self._llm,
                 parser=self._parser,
                 max_retries=self.max_prompts,
-            )
-        else:
-            self._refiner = JanusRefiner(parser=self._parser)
+            ).parse_completion(**x)
+        )
 
-    @run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner")
+    @run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner_chain")
     def _load_chain(self):
-        self.chain = (
-            self._input_runnable()
-            | self._prompt
-            | RunnableParallel(
-                completion=self._llm,
-                prompt_value=RunnablePassthrough(),
-            )
-            | self._refiner.parse_runnable
-        )
+        self.chain = self._input_runnable() | self._prompt | self._refiner_chain
 
     def _input_runnable(self) -> Runnable:
         return RunnableParallel(
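
Note on the lambda-capture comment in _load_refiner_chain above: Python closures bind loop variables late, so a lambda that merely referenced refiner_type would look the name up only when LangChain invokes the chain, and every stage would then see whichever class the loop assigned last. Binding it as a default argument (refiner_type=refiner_type) captures the value at definition time, once per iteration. A minimal standalone illustration of the pitfall (hypothetical values, not package code):

    # Late binding: every lambda sees the final value of `name`.
    callbacks = [lambda: name for name in ("parser", "reflection")]
    print([cb() for cb in callbacks])  # ['reflection', 'reflection']

    # A default argument captures the current value at each iteration,
    # which is what refiner_type=refiner_type accomplishes in the chain above.
    callbacks = [lambda name=name: name for name in ("parser", "reflection")]
    print([cb() for cb in callbacks])  # ['parser', 'reflection']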
@@ -466,6 +464,7 @@ class Converter:
         for in_path, out_path in in_out_pairs:
             # Translate the file, skip it if there's a rate limit error
             try:
+                log.info(f"Processing {in_path.relative_to(input_directory)}")
                 out_block = self.translate_file(in_path)
                 total_cost += out_block.total_cost
             except RateLimitError: