janus-llm 4.4.5__py3-none-any.whl → 4.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. janus/__init__.py +1 -1
  2. janus/cli/pipeline.py +6 -3
  3. janus/cli/self_eval.py +9 -0
  4. janus/converter/__init__.py +2 -0
  5. janus/converter/_tests/test_translate.py +1 -0
  6. janus/converter/chain.py +53 -133
  7. janus/converter/converter.py +199 -77
  8. janus/converter/diagram.py +5 -3
  9. janus/converter/document.py +10 -4
  10. janus/converter/evaluate.py +148 -113
  11. janus/converter/partition.py +4 -1
  12. janus/converter/passthrough.py +29 -0
  13. janus/converter/pool.py +74 -0
  14. janus/converter/requirements.py +4 -1
  15. janus/language/_tests/test_combine.py +1 -0
  16. janus/language/block.py +84 -3
  17. janus/llm/model_callbacks.py +6 -0
  18. janus/llm/models_info.py +19 -0
  19. janus/metrics/_tests/test_reading.py +48 -4
  20. janus/metrics/_tests/test_rouge_score.py +5 -11
  21. janus/metrics/reading.py +48 -28
  22. janus/metrics/rouge_score.py +21 -34
  23. janus/parsers/_tests/test_code_parser.py +1 -1
  24. janus/parsers/code_parser.py +2 -2
  25. janus/parsers/eval_parsers/incose_parser.py +3 -3
  26. janus/prompts/templates/cyclic/human.txt +16 -0
  27. janus/prompts/templates/cyclic/system.txt +1 -0
  28. janus/prompts/templates/eval_prompts/incose/human.txt +1 -1
  29. janus/prompts/templates/extract_variables/human.txt +5 -0
  30. janus/prompts/templates/extract_variables/system.txt +1 -0
  31. {janus_llm-4.4.5.dist-info → janus_llm-4.5.4.dist-info}/METADATA +3 -4
  32. {janus_llm-4.4.5.dist-info → janus_llm-4.5.4.dist-info}/RECORD +35 -29
  33. {janus_llm-4.4.5.dist-info → janus_llm-4.5.4.dist-info}/WHEEL +1 -1
  34. {janus_llm-4.4.5.dist-info → janus_llm-4.5.4.dist-info}/LICENSE +0 -0
  35. {janus_llm-4.4.5.dist-info → janus_llm-4.5.4.dist-info}/entry_points.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  import functools
2
2
  import json
3
3
  import time
4
+ from copy import deepcopy
4
5
  from pathlib import Path
5
6
  from typing import Any
6
7
 
@@ -16,7 +17,7 @@ from openai import BadRequestError, RateLimitError
16
17
  from pydantic import ValidationError
17
18
 
18
19
  from janus.embedding.vectorize import ChromaDBVectorizer
19
- from janus.language.block import CodeBlock, TranslatedCodeBlock
20
+ from janus.language.block import BlockCollection, CodeBlock, TranslatedCodeBlock
20
21
  from janus.language.combine import Combiner
21
22
  from janus.language.naive.registry import CUSTOM_SPLITTERS
22
23
  from janus.language.splitter import (
@@ -82,12 +83,16 @@ class Converter:
82
83
  protected_node_types: tuple[str, ...] = (),
83
84
  prune_node_types: tuple[str, ...] = (),
84
85
  splitter_type: str = "file",
85
- refiner_types: list[type[JanusRefiner]] = [JanusRefiner],
86
+ refiner_types: list[type[JanusRefiner] | str] = [JanusRefiner],
86
87
  retriever_type: str | None = None,
87
88
  combine_output: bool = True,
88
89
  use_janus_inputs: bool = False,
89
90
  target_language: str = "json",
90
91
  target_version: str | None = None,
92
+ input_types: set[str] | str | None = None,
93
+ input_labels: set[str] | str | None = None,
94
+ output_type: str | None = None,
95
+ output_label: str | None = None,
91
96
  ) -> None:
92
97
  """Initialize a Converter instance.
93
98
 
@@ -119,6 +124,10 @@ class Converter:
119
124
  use_janus_inputs: Whether to use janus inputs or not.
120
125
  target_language: The target programming language.
121
126
  target_version: The target programming language version.
127
+ input_types: The types of input to accept.
128
+ input_labels: The labels of input to accept.
129
+ output_type: The type of output to produce.
130
+ output_label: The label of output to produce.
122
131
  """
123
132
  self._changed_attrs: set = set()
124
133
 
@@ -154,7 +163,7 @@ class Converter:
154
163
  self._combiner: Combiner = Combiner()
155
164
 
156
165
  self._splitter_type: str
157
- self._refiner_types: list[type[JanusRefiner]]
166
+ self._refiner_types: list[type[JanusRefiner] | str]
158
167
  self._retriever_type: str | None
159
168
 
160
169
  self._splitter: Splitter
@@ -172,6 +181,13 @@ class Converter:
172
181
  self.set_db_path(db_path=db_path)
173
182
  self.set_db_config(db_config=db_config)
174
183
 
184
+ self._input_types = input_types
185
+ self._input_labels = input_labels
186
+ self._output_type = output_type
187
+ self._output_label = output_label
188
+
189
+ self._load_parameters()
190
+
175
191
  # Child class must call this. Should we enforce somehow?
176
192
  # self._load_parameters()
177
193
 
@@ -230,7 +246,7 @@ class Converter:
230
246
 
231
247
  self._splitter_type = splitter_type
232
248
 
233
- def set_refiner_types(self, refiner_types: list[type[JanusRefiner]]) -> None:
249
+ def set_refiner_types(self, refiner_types: list[type[JanusRefiner] | str]) -> None:
234
250
  """Validate and set the refiner type
235
251
 
236
252
  Arguments:
@@ -342,7 +358,13 @@ class Converter:
342
358
  if not self.override_token_limit:
343
359
  self._max_tokens = int(token_limit * self._llm.input_token_proportion)
344
360
 
345
- @run_if_changed("_prompt_template_names", "_source_language", "_model_name")
361
+ @run_if_changed(
362
+ "_prompt_template_names",
363
+ "_source_language",
364
+ "_model_name",
365
+ "_target_language",
366
+ "_target_version",
367
+ )
346
368
  def _load_translation_chain(self) -> None:
347
369
  prompt_template_name = self._prompt_template_names[0]
348
370
  prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
@@ -404,12 +426,18 @@ class Converter:
404
426
 
405
427
  @run_if_changed("_refiner_types", "_model_name", "max_prompts", "_parser")
406
428
  def _load_refiner_chain(self) -> None:
429
+ from janus.cli.constants import REFINERS
430
+
407
431
  if len(self._refiner_types) == 0:
408
432
  self._refiner_chain = RunnableLambda(
409
433
  lambda x: self._parser.parse(x["completion"])
410
434
  )
411
435
  return
412
436
  refiner_type = self._refiner_types[0]
437
+ if isinstance(refiner_type, str):
438
+ if refiner_type not in REFINERS:
439
+ raise ValueError(f"Error: unable to find refiner type {refiner_type}")
440
+ refiner_type = REFINERS[refiner_type]
413
441
  if len(self._refiner_types) == 1:
414
442
  self._refiner_chain = RunnableLambda(
415
443
  lambda x, refiner_type=refiner_type: refiner_type(
@@ -429,6 +457,10 @@ class Converter:
429
457
  prompt_value=lambda x: x["prompt_value"],
430
458
  )
431
459
  for refiner_type in self._refiner_types[1:-1]:
460
+ if isinstance(refiner_type, str):
461
+ if refiner_type not in REFINERS:
462
+ raise ValueError(f"Error: unable to find refiner type {refiner_type}")
463
+ refiner_type = REFINERS[refiner_type]
432
464
  # NOTE: Do NOT remove refiner_type=refiner_type from lambda.
433
465
  # Due to lambda capture, must be present or chain will not
434
466
  # be correctly constructed.
@@ -448,7 +480,15 @@ class Converter:
448
480
  ).parse_completion(**x)
449
481
  )
450
482
 
451
- @run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner_chain")
483
+ @run_if_changed(
484
+ "_parser",
485
+ "_retriever",
486
+ "_prompt",
487
+ "_llm",
488
+ "_refiner_chain",
489
+ "_target_language",
490
+ "_target_version",
491
+ )
452
492
  def _load_chain(self):
453
493
  self.chain = self.get_chain()
454
494
 
@@ -561,13 +601,12 @@ class Converter:
561
601
  # For files where translation failed, write to failure path instead
562
602
 
563
603
  def _has_empty(block):
564
- if isinstance(block, list):
565
- return len(block) == 0 or any(_has_empty(b) for b in block)
604
+ if isinstance(block, BlockCollection):
605
+ return len(block.blocks) == 0 or any(
606
+ _has_empty(b) for b in block.blocks
607
+ )
566
608
  return not block.translated
567
609
 
568
- while isinstance(out_block, list) and len(out_block) == 1:
569
- out_block = out_block[0]
570
-
571
610
  if _has_empty(out_block):
572
611
  if fail_path is not None:
573
612
  self._save_to_file(out_block, fail_path)
@@ -582,28 +621,58 @@ class Converter:
582
621
 
583
622
  # Make sure the tree's code has been consolidated at the top level
584
623
  # before writing to file
585
- self._combiner.combine(out_block)
624
+ for b in out_block.blocks:
625
+ self._combiner.combine(b)
586
626
  if out_path is not None and (overwrite or not out_path.exists()):
587
627
  self._save_to_file(out_block, out_path)
588
628
 
589
629
  log.info(f"Total cost: ${total_cost:,.2f}")
590
630
 
631
+ def _filter_blocks(self, code_block):
632
+ if isinstance(code_block, BlockCollection):
633
+ input_blocks = list(code_block.blocks)
634
+ else:
635
+ input_blocks = [code_block]
636
+ if self._input_types is not None:
637
+ if isinstance(self._input_types, str):
638
+ self._input_types = set([self._input_types])
639
+ input_blocks = [
640
+ b
641
+ for b in input_blocks
642
+ if isinstance(b, BlockCollection) or b.block_type in self._input_types
643
+ ]
644
+ if self._input_labels is not None:
645
+ if isinstance(self._input_labels, str):
646
+ self._input_labels = set([self._input_labels])
647
+ input_blocks = [
648
+ b
649
+ for b in input_blocks
650
+ if isinstance(b, BlockCollection) or b.block_label in self._input_labels
651
+ ]
652
+ return input_blocks
653
+
654
+ def translate_blocks(
655
+ self,
656
+ code_block: CodeBlock | BlockCollection,
657
+ failure_path: Path | None = None,
658
+ ) -> BlockCollection | TranslatedCodeBlock:
659
+ input_blocks = self._filter_blocks(code_block)
660
+ output_blocks = []
661
+ for b in input_blocks:
662
+ output_blocks.append(self.translate_block(b, failure_path))
663
+ return BlockCollection(output_blocks, code_block.previous_generations)
664
+
591
665
  def translate_block(
592
666
  self,
593
- input_block: CodeBlock | list[CodeBlock],
594
- name: str,
667
+ input_block: CodeBlock,
595
668
  failure_path: Path | None = None,
596
- ):
669
+ ) -> TranslatedCodeBlock:
597
670
  self._load_parameters()
598
- if isinstance(input_block, list):
599
- return [self.translate_block(b, name, failure_path) for b in input_block]
600
- t0 = time.time()
601
671
  output_block = self._iterative_translate(input_block, failure_path)
602
- output_block.processing_time = time.time() - t0
603
672
  if output_block.translated:
604
673
  completeness = output_block.translation_completeness
605
674
  log.info(
606
- f"[{name}] Translation complete\n"
675
+ f"[{output_block.name}] Translation complete\n"
607
676
  f" {completeness:.2%} of input successfully translated\n"
608
677
  f" Total cost: ${output_block.total_cost:,.2f}\n"
609
678
  f" Output CodeBlock Structure:\n{input_block.tree_str()}\n"
@@ -611,7 +680,7 @@ class Converter:
611
680
 
612
681
  else:
613
682
  log.error(
614
- f"[{name}] Translation failed\n"
683
+ f"[{output_block.name}] Translation failed\n"
615
684
  f" Total cost: ${output_block.total_cost:,.2f}\n"
616
685
  )
617
686
  return output_block
@@ -632,9 +701,8 @@ class Converter:
632
701
  code is not guaranteed to be consolidated. To amend this, run
633
702
  `Combiner.combine_children` on the block.
634
703
  """
635
- filename = file.name
636
704
  input_block = self._split_file(file)
637
- return self.translate_block(input_block, filename, failure_path)
705
+ return self.translate_blocks(input_block, failure_path)
638
706
 
639
707
  def translate_janus_file(self, file: Path, failure_path: Path | None = None):
640
708
  filename = file.name
@@ -644,7 +712,7 @@ class Converter:
644
712
 
645
713
  def translate_janus_obj(self, obj: Any, name: str, failure_path: Path | None = None):
646
714
  block = self._janus_object_to_codeblock(obj, name)
647
- return self.translate_block(block)
715
+ return self.translate_blocks(block, failure_path)
648
716
 
649
717
  def translate_text(self, text: str, name: str, failure_path: Path | None = None):
650
718
  """
@@ -655,7 +723,7 @@ class Converter:
655
723
  failure_path: path to write failure file if translation is not successful
656
724
  """
657
725
  input_block = self._split_text(text, name)
658
- return self.translate_block(input_block, name, failure_path)
726
+ return self.translate_blocks(input_block, failure_path)
659
727
 
660
728
  def _iterative_translate(
661
729
  self, root: CodeBlock, failure_path: Path | None = None
@@ -669,7 +737,13 @@ class Converter:
669
737
  Returns:
670
738
  A `TranslatedCodeBlock`
671
739
  """
672
- translated_root = TranslatedCodeBlock(root, self._target_language)
740
+ translated_root = TranslatedCodeBlock(
741
+ root,
742
+ self._target_language,
743
+ self,
744
+ block_type=self._output_type,
745
+ block_label=self._output_label,
746
+ )
673
747
  last_prog, prog_delta = 0, 0.1
674
748
  stack = [translated_root]
675
749
  try:
@@ -692,7 +766,7 @@ class Converter:
692
766
  except RateLimitError:
693
767
  pass
694
768
  except OutputParserException as e:
695
- log.error(f"Skipping file, failed to parse output: {e}.")
769
+ log.error(f"Skipping file, failed to parse output: {e}")
696
770
  except BadRequestError as e:
697
771
  if str(e).startswith("Detected an error in the prompt"):
698
772
  log.warning("Malformed input, skipping")
@@ -720,7 +794,9 @@ class Converter:
720
794
  )
721
795
  raise e
722
796
  finally:
723
- out_obj = self._get_output_obj(translated_root, self._combine_output)
797
+ out_obj = self._get_output_obj(
798
+ translated_root, self._combine_output, include_previous_outputs=True
799
+ )
724
800
  log.debug(f"Resulting Block:" f"{json.dumps(out_obj)}")
725
801
  if not translated_root.translated:
726
802
  if failure_path is not None:
@@ -810,65 +886,105 @@ class Converter:
810
886
  input_tokens=sum(m["input_tokens"] for m in metadatas),
811
887
  output_tokens=sum(m["output_tokens"] for m in metadatas),
812
888
  converter_name=self.__class__.__name__,
889
+ type=[m["type"] for m in metadatas],
890
+ label=[m["label"] for m in metadatas],
813
891
  )
814
892
 
815
893
  def _combine_inputs(self, inputs: list[str]):
816
- s = ""
817
- for i in inputs:
818
- s += i
819
- return s
894
+ return json.dumps(inputs)
820
895
 
821
896
  def _get_output_obj(
822
- self, block: TranslatedCodeBlock | list, combine_children: bool = True
897
+ self,
898
+ block: TranslatedCodeBlock | BlockCollection | dict,
899
+ combine_children: bool = True,
900
+ include_previous_outputs: bool = True,
823
901
  ) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
824
- if isinstance(block, list):
825
- # TODO: run on all items in list
826
- outputs = [self._get_output_obj(b, combine_children) for b in block]
827
- metadata = self._combine_metadata([o["metadata"] for o in outputs])
828
- input_agg = self._combine_inputs(o["input"] for o in outputs)
829
- return dict(
830
- input=input_agg,
831
- metadata=metadata,
832
- outputs=outputs,
833
- )
834
- if not combine_children and len(block.children) > 0:
835
- outputs = self._get_output_obj_children(block)
836
- metadata = self._combine_metadata([o["metadata"] for o in outputs])
837
- input_agg = self._combine_inputs(o["input"] for o in outputs)
838
- return dict(
839
- input=input_agg,
840
- metadata=metadata,
841
- outputs=outputs,
842
- )
843
- output_obj: str | dict[str, str]
844
- if not block.translation_completed:
845
- # translation wasn't completed, so combined parsing will likely fail
846
- output_obj = [block.complete_text]
902
+ block_type = None
903
+ block_label = None
904
+ if isinstance(block, dict):
905
+ # output object has already been generated
906
+ new_block = deepcopy(block)
907
+ if "intermediate_outputs" in new_block:
908
+ del new_block["intermediate_outputs"]
909
+ return new_block
910
+ if isinstance(block, BlockCollection):
911
+ if len(block.blocks) == 1:
912
+ outputs = self._get_output_obj(block.blocks[0], combine_children, False)[
913
+ "outputs"
914
+ ]
915
+ block_type = block.blocks[0].block_type
916
+ block_label = block.blocks[0].block_label
917
+ else:
918
+ outputs = [
919
+ self._get_output_obj(b, combine_children, False) for b in block.blocks
920
+ ]
921
+ elif (
922
+ not isinstance(block, BlockCollection)
923
+ and not combine_children
924
+ and len(block.children) > 0
925
+ ):
926
+ outputs = self._get_output_obj_children(block, False)
847
927
  else:
848
- output_str = self._parser.parse_combined_output(block.complete_text)
849
- output_obj = [output_str]
928
+ block_type = block.block_type
929
+ block_label = block.block_label
930
+ if not block.translation_completed:
931
+ # translation wasn't completed, so combined parsing will likely fail
932
+ outputs = [block.complete_text]
933
+ else:
934
+ output_str = self._parser.parse_combined_output(block.complete_text)
935
+ outputs = [output_str]
850
936
 
851
- return dict(
852
- input=block.original.text or "",
937
+ def _get_input(block):
938
+ if isinstance(block, BlockCollection):
939
+ return self._combine_inputs([_get_input(b) for b in block.blocks])
940
+ return block.original.text or ""
941
+
942
+ out = dict(
943
+ input=_get_input(block),
853
944
  metadata=dict(
854
945
  cost=block.total_cost,
855
- processing_time=block.processing_time,
946
+ processing_time=block.total_processing_time,
856
947
  num_requests=block.total_num_requests,
857
948
  input_tokens=block.total_request_input_tokens,
858
949
  output_tokens=block.total_request_output_tokens,
859
950
  converter_name=self.__class__.__name__,
951
+ type=block_type,
952
+ label=block_label,
860
953
  ),
861
- outputs=output_obj,
954
+ outputs=outputs,
862
955
  )
863
-
864
- def _get_output_obj_children(self, block: TranslatedCodeBlock):
956
+ if (
957
+ include_previous_outputs
958
+ and isinstance(block, BlockCollection)
959
+ and len(block.previous_generations) > 0
960
+ ):
961
+ intermediate_outputs = []
962
+ for p in block.previous_generations:
963
+ if isinstance(p, dict):
964
+ # preserve intermediate outputs from previous runs
965
+ intermediate_outputs.append(
966
+ self._get_output_obj(p, combine_children, False)
967
+ )
968
+ if len(intermediate_outputs) > 0:
969
+ out["intermediate_outputs"] = intermediate_outputs
970
+ return out
971
+
972
+ def _get_output_obj_children(
973
+ self, block: TranslatedCodeBlock, include_previous_outputs: bool = True
974
+ ):
865
975
  if len(block.children) > 0:
866
976
  res = []
867
977
  for c in block.children:
868
- res += self._get_output_obj_children(c)
978
+ res += self._get_output_obj_children(c, include_previous_outputs)
869
979
  return res
870
980
  else:
871
- return [self._get_output_obj(block, combine_children=True)]
981
+ return [
982
+ self._get_output_obj(
983
+ block,
984
+ combine_children=True,
985
+ include_previous_outputs=include_previous_outputs,
986
+ )
987
+ ]
872
988
 
873
989
  def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
874
990
  """Save a file to disk.
@@ -876,30 +992,36 @@ class Converter:
876
992
  Arguments:
877
993
  block: The `TranslatedCodeBlock` to save to a file.
878
994
  """
879
- obj = self._get_output_obj(block, combine_children=self._combine_output)
995
+ obj = self._get_output_obj(
996
+ block, combine_children=self._combine_output, include_previous_outputs=True
997
+ )
880
998
  out_path.parent.mkdir(parents=True, exist_ok=True)
881
999
  out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
882
1000
 
883
1001
  def _janus_object_to_codeblock(self, janus_obj: dict, name: str):
884
1002
  results = []
885
1003
  for o in janus_obj["outputs"]:
1004
+ metadata = janus_obj["metadata"]
886
1005
  if isinstance(o, str):
1006
+ block_label = metadata["label"]
1007
+ if isinstance(block_label, list):
1008
+ block_label = block_label[0]
1009
+ block_type = metadata["type"]
1010
+ if isinstance(block_type, list):
1011
+ block_type = block_type[0]
887
1012
  code_block = self._split_text(o, name)
888
- meta_data = janus_obj["metadata"]
889
- code_block.initial_cost = meta_data["cost"]
890
- code_block.initial_input_tokens = meta_data["input_tokens"]
891
- code_block.initial_output_tokens = meta_data["output_tokens"]
892
- code_block.initial_num_requests = meta_data["num_requests"]
893
- code_block.initial_processing_time = meta_data["processing_time"]
894
1013
  code_block.previous_generations = janus_obj.get(
895
1014
  "intermediate_outputs", []
896
1015
  ) + [janus_obj]
1016
+ code_block.block_type = block_type
1017
+ code_block.block_label = block_label
897
1018
  results.append(code_block)
898
1019
  else:
899
- results.append(self._janus_object_to_codeblock(o))
900
- while isinstance(results, list) and len(results) == 1:
901
- results = results[0]
902
- return results
1020
+ results += self._janus_object_to_codeblock(o, name).blocks
1021
+ previous_generations = janus_obj.get("intermediate_outputs", [])
1022
+ if janus_obj["metadata"]["converter_name"] != "ConverterChain":
1023
+ previous_generations += [janus_obj]
1024
+ return BlockCollection(results, previous_generations)
903
1025
 
904
1026
  def __or__(self, other: "Converter"):
905
1027
  from janus.converter.chain import ConverterChain
@@ -12,9 +12,10 @@ class DiagramGenerator(Documenter):
12
12
 
13
13
  def __init__(
14
14
  self,
15
- diagram_type="Activity",
16
- add_documentation=False,
17
- extract_variables=False,
15
+ diagram_type: str = "Activity",
16
+ add_documentation: bool = False,
17
+ extract_variables: bool = False,
18
+ output_type: str = "diagram",
18
19
  **kwargs,
19
20
  ) -> None:
20
21
  """Initialize the DiagramGenerator class
@@ -28,6 +29,7 @@ class DiagramGenerator(Documenter):
28
29
  self._add_documentation = add_documentation
29
30
  self._documenter = Documenter(**kwargs)
30
31
 
32
+ kwargs.update(dict(output_type=output_type))
31
33
  super().__init__(**kwargs)
32
34
  prompts = []
33
35
  if extract_variables:
@@ -15,9 +15,13 @@ log = create_logger(__name__)
15
15
 
16
16
  class Documenter(Converter):
17
17
  def __init__(
18
- self, source_language: str = "fortran", drop_comments: bool = True, **kwargs
18
+ self,
19
+ source_language: str = "fortran",
20
+ drop_comments: bool = True,
21
+ output_type: str = "documentation",
22
+ **kwargs,
19
23
  ):
20
- kwargs.update(source_language=source_language)
24
+ kwargs.update(source_language=source_language, output_type=output_type)
21
25
  super().__init__(**kwargs)
22
26
  self.set_prompts("document")
23
27
 
@@ -31,7 +35,8 @@ class Documenter(Converter):
31
35
 
32
36
 
33
37
  class MultiDocumenter(Documenter):
34
- def __init__(self, **kwargs):
38
+ def __init__(self, output_type: str = "multidocumentation", **kwargs):
39
+ kwargs.update(output_type=output_type)
35
40
  super().__init__(**kwargs)
36
41
  self.set_prompts("multidocument")
37
42
  self._combiner = JsonCombiner()
@@ -44,9 +49,10 @@ class ClozeDocumenter(Documenter):
44
49
  def __init__(
45
50
  self,
46
51
  comments_per_request: int | None = None,
52
+ output_type: str = "cloze_comments",
47
53
  **kwargs,
48
54
  ) -> None:
49
- kwargs.update(drop_comments=False)
55
+ kwargs.update(drop_comments=False, output_type=output_type)
50
56
  super().__init__(**kwargs)
51
57
  self.set_prompts("document_cloze")
52
58
  self._combiner = JsonCombiner()