janus-llm 4.3.5__py3-none-any.whl → 4.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -76,7 +76,7 @@ class Converter:
76
76
  source_language: str = "fortran",
77
77
  max_prompts: int = 10,
78
78
  max_tokens: int | None = None,
79
- prompt_template: str = "simple",
79
+ prompt_templates: list[str] | str = ["simple"],
80
80
  db_path: str | None = None,
81
81
  db_config: dict[str, Any] | None = None,
82
82
  protected_node_types: tuple[str, ...] = (),
@@ -84,6 +84,10 @@ class Converter:
84
84
  splitter_type: str = "file",
85
85
  refiner_types: list[type[JanusRefiner]] = [JanusRefiner],
86
86
  retriever_type: str | None = None,
87
+ combine_output: bool = True,
88
+ use_janus_inputs: bool = False,
89
+ target_language: str = "json",
90
+ target_version: str | None = None,
87
91
  ) -> None:
88
92
  """Initialize a Converter instance.
89
93
 
@@ -96,7 +100,7 @@ class Converter:
96
100
  max_prompts: The maximum number of prompts to try before giving up.
97
101
  max_tokens: The maximum number of tokens to use in the LLM. If `None`, the
98
102
  converter will use half the model's token limit.
99
- prompt_template: The name of the prompt template to use.
103
+ prompt_templates: The name of the prompt templates to use.
100
104
  db_path: The path to the database to use for vectorization.
101
105
  db_config: The configuration for the database.
102
106
  protected_node_types: A set of node types that aren't to be merged.
@@ -111,12 +115,17 @@ class Converter:
111
115
  - "active_usings"
112
116
  - "language_docs"
113
117
  - None
118
+ combine_output: Whether to combine the output into a single file or not.
119
+ use_janus_inputs: Whether to use janus inputs or not.
120
+ target_language: The target programming language.
121
+ target_version: The target programming language version.
114
122
  """
115
123
  self._changed_attrs: set = set()
116
124
 
117
125
  self.max_prompts: int = max_prompts
118
126
  self._max_tokens: int | None = max_tokens
119
127
  self.override_token_limit: bool = max_tokens is not None
128
+ self._combine_output = combine_output
120
129
 
121
130
  self._model_name: str
122
131
  self._custom_model_arguments: dict[str, Any]
@@ -124,13 +133,16 @@ class Converter:
124
133
  self._source_language: str
125
134
  self._source_suffixes: list[str]
126
135
 
127
- self._target_language = "json"
128
- self._target_suffix = ".json"
136
+ self._target_language: str
137
+ self._target_suffix: str
138
+ self._target_version: str | None
139
+ self.set_target_language(target_language, target_version)
140
+ self._use_janus_inputs = use_janus_inputs
129
141
 
130
142
  self._protected_node_types: tuple[str, ...] = ()
131
143
  self._prune_node_types: tuple[str, ...] = ()
132
144
  self._max_tokens: int | None = max_tokens
133
- self._prompt_template_name: str
145
+ self._prompt_template_names: list[str]
134
146
  self._db_path: str | None
135
147
  self._db_config: dict[str, Any] | None
136
148
 
@@ -153,7 +165,7 @@ class Converter:
153
165
  self.set_refiner_types(refiner_types=refiner_types)
154
166
  self.set_retriever(retriever_type=retriever_type)
155
167
  self.set_model(model_name=model, **model_arguments)
156
- self.set_prompt(prompt_template=prompt_template)
168
+ self.set_prompts(prompt_templates=prompt_templates)
157
169
  self.set_source_language(source_language)
158
170
  self.set_protected_node_types(protected_node_types)
159
171
  self.set_prune_node_types(prune_node_types)
@@ -174,7 +186,7 @@ class Converter:
174
186
 
175
187
  def _load_parameters(self) -> None:
176
188
  self._load_model()
177
- self._load_prompt()
189
+ self._load_translation_chain()
178
190
  self._load_retriever()
179
191
  self._load_refiner_chain()
180
192
  self._load_splitter()
@@ -195,21 +207,23 @@ class Converter:
195
207
  self._model_name = model_name
196
208
  self._custom_model_arguments = custom_arguments
197
209
 
198
- def set_prompt(self, prompt_template: str) -> None:
210
+ def set_prompts(self, prompt_templates: list[str] | str) -> None:
199
211
  """Validate and set the prompt template name.
200
212
 
201
213
  Arguments:
202
- prompt_template: name of prompt template directory
203
- (see janus/prompts/templates) or path to a directory.
214
+ prompt_templates: name of prompt template directories
215
+ (see janus/prompts/templates) or paths to directories.
204
216
  """
205
- self._prompt_template_name = prompt_template
217
+ if isinstance(prompt_templates, str):
218
+ self._prompt_template_names = [prompt_templates]
219
+ else:
220
+ self._prompt_template_names = prompt_templates
206
221
 
207
222
  def set_splitter(self, splitter_type: str) -> None:
208
223
  """Validate and set the prompt template name.
209
224
 
210
225
  Arguments:
211
- prompt_template: name of prompt template directory
212
- (see janus/prompts/templates) or path to a directory.
226
+ splitter_type: the type of splitter to use
213
227
  """
214
228
  if splitter_type not in CUSTOM_SPLITTERS:
215
229
  raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
@@ -328,26 +342,46 @@ class Converter:
328
342
  if not self.override_token_limit:
329
343
  self._max_tokens = int(token_limit * self._llm.input_token_proportion)
330
344
 
331
- @run_if_changed(
332
- "_prompt_template_name",
333
- "_source_language",
334
- "_model_name",
335
- "_parser",
336
- )
337
- def _load_prompt(self) -> None:
338
- """Load the prompt according to this instance's attributes.
339
-
340
- If the relevant fields have not been changed since the last time this
341
- method was called, nothing happens.
342
- """
345
+ @run_if_changed("_prompt_template_names", "_source_language", "_model_name")
346
+ def _load_translation_chain(self) -> None:
347
+ prompt_template_name = self._prompt_template_names[0]
343
348
  prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
344
349
  source_language=self._source_language,
345
- prompt_template=self._prompt_template_name,
350
+ prompt_template=prompt_template_name,
351
+ target_language=self._target_language,
352
+ target_version=self._target_version,
346
353
  )
347
- self._prompt = prompt_engine.prompt
348
- self._prompt = self._prompt.partial(
349
- format_instructions=self._parser.get_format_instructions()
354
+ prompt = prompt_engine.prompt
355
+ self._translation_chain = RunnableParallel(
356
+ prompt_value=lambda x, prompt=prompt: prompt.invoke(x),
357
+ original_inputs=RunnablePassthrough(),
358
+ ) | RunnableParallel(
359
+ completion=lambda x: self._llm.invoke(x["prompt_value"]),
360
+ original_inputs=lambda x: x["original_inputs"],
361
+ prompt_value=lambda x: x["prompt_value"],
350
362
  )
363
+ for prompt_template_name in self._prompt_template_names[1:]:
364
+ prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
365
+ source_language=self._source_language,
366
+ prompt_template=prompt_template_name,
367
+ target_language=self._target_language,
368
+ target_version=self._target_version,
369
+ )
370
+ prompt = prompt_engine.prompt
371
+ self._translation_chain = (
372
+ self._translation_chain
373
+ | RunnableParallel(
374
+ prompt_value=lambda x, prompt=prompt: prompt.invoke(
375
+ dict(completion=x["completion"], **x["original_inputs"])
376
+ ),
377
+ original_inputs=lambda x: x["original_inputs"],
378
+ )
379
+ | RunnableParallel(
380
+ completion=lambda x: self._llm.invoke(x["prompt_value"]),
381
+ original_inputs=lambda x: x["original_inputs"],
382
+ prompt_value=lambda x: x["prompt_value"],
383
+ )
384
+ )
351
385
 
352
386
  @run_if_changed("_db_path", "_db_config")
353
387
  def _load_vectorizer(self) -> None:
@@ -370,11 +404,31 @@ class Converter:
370
404
 
371
405
  @run_if_changed("_refiner_types", "_model_name", "max_prompts", "_parser")
372
406
  def _load_refiner_chain(self) -> None:
373
- self._refiner_chain = RunnableParallel(
374
- completion=self._llm,
375
- prompt_value=RunnablePassthrough(),
376
- )
377
- for refiner_type in self._refiner_types[:-1]:
407
+ if len(self._refiner_types) == 0:
408
+ self._refiner_chain = RunnableLambda(
409
+ lambda x: self._parser.parse(x["completion"])
410
+ )
411
+ return
412
+ refiner_type = self._refiner_types[0]
413
+ if len(self._refiner_types) == 1:
414
+ self._refiner_chain = RunnableLambda(
415
+ lambda x, refiner_type=refiner_type: refiner_type(
416
+ llm=self._llm,
417
+ parser=self._parser,
418
+ max_retries=self.max_prompts,
419
+ ).parse_completion(**x)
420
+ )
421
+ return
422
+ else:
423
+ self._refiner_chain = RunnableParallel(
424
+ completion=lambda x, refiner_type=refiner_type: refiner_type(
425
+ llm=self._llm,
426
+ parser=self._base_parser,
427
+ max_retries=self.max_prompts,
428
+ ).parse_completion(**x),
429
+ prompt_value=lambda x: x["prompt_value"],
430
+ )
431
+ for refiner_type in self._refiner_types[1:-1]:
378
432
  # NOTE: Do NOT remove refiner_type=refiner_type from lambda.
379
433
  # Due to lambda capture, must be present or chain will not
380
434
  # be correctly constructed.
@@ -396,7 +450,7 @@ class Converter:
396
450
 
397
451
  @run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner_chain")
398
452
  def _load_chain(self):
399
- self.chain = self._input_runnable() | self._prompt | self._refiner_chain
453
+ self.chain = self.get_chain()
400
454
 
401
455
  def _input_runnable(self) -> Runnable:
402
456
  return RunnableParallel(
@@ -404,6 +458,12 @@ class Converter:
404
458
  context=self._retriever,
405
459
  )
406
460
 
461
+ def get_chain(self) -> Runnable:
462
+ """
463
+ Gets a chain that can be executed by langchain
464
+ """
465
+ return self._input_runnable() | self._translation_chain | self._refiner_chain
466
+
407
467
  def translate(
408
468
  self,
409
469
  input_directory: str | Path,
@@ -436,22 +496,24 @@ class Converter:
436
496
  failure_directory.mkdir(parents=True)
437
497
 
438
498
  input_paths = []
439
- for ext in self._source_suffixes:
499
+ if self._use_janus_inputs:
500
+ source_language = "janus"
501
+ source_suffixes = [".json"]
502
+ else:
503
+ source_language = self._source_language
504
+ source_suffixes = self._source_suffixes
505
+ for ext in source_suffixes:
440
506
  input_paths.extend(input_directory.rglob(f"**/*{ext}"))
441
507
 
442
508
  log.info(f"Input directory: {input_directory.absolute()}")
443
- log.info(
444
- f"{self._source_language} {self._source_suffixes} files: "
445
- f"{len(input_paths)}"
446
- )
509
+ log.info(f"{source_language} {source_suffixes} files: " f"{len(input_paths)}")
447
510
  log.info(
448
511
  "Other files (skipped): "
449
512
  f"{len(list(input_directory.iterdir())) - len(input_paths)}\n"
450
513
  )
451
514
  if output_directory is not None:
452
515
  output_paths = [
453
- output_directory
454
- / p.relative_to(input_directory).with_suffix(self._target_suffix)
516
+ output_directory / p.relative_to(input_directory).with_suffix(".json")
455
517
  for p in input_paths
456
518
  ]
457
519
  else:
@@ -459,8 +521,7 @@ class Converter:
459
521
 
460
522
  if failure_directory is not None:
461
523
  failure_paths = [
462
- failure_directory
463
- / p.relative_to(input_directory).with_suffix(self._target_suffix)
524
+ failure_directory / p.relative_to(input_directory).with_suffix(".json")
464
525
  for p in input_paths
465
526
  ]
466
527
  else:
@@ -484,12 +545,32 @@ class Converter:
484
545
  for in_path, out_path, fail_path in in_out_pairs:
485
546
  # Translate the file, skip it if there's a rate limit error
486
547
  log.info(f"Processing {in_path.relative_to(input_directory)}")
487
- out_block = self.translate_file(in_path, fail_path)
488
- total_cost += out_block.total_cost
548
+ if self._use_janus_inputs:
549
+ out_block = self.translate_janus_file(in_path, fail_path)
550
+ else:
551
+ out_block = self.translate_file(in_path, fail_path)
552
+
553
+ def _get_total_cost(block):
554
+ if isinstance(block, list):
555
+ return sum(_get_total_cost(b) for b in block)
556
+ return block.total_cost
557
+
558
+ total_cost += _get_total_cost(out_block)
489
559
  log.info(f"Current Running Cost: {total_cost}")
490
560
 
491
- # Don't attempt to write files for which translation failed
492
- if not out_block.translated:
561
+ # For files where translation failed, write to failure path instead
562
+
563
+ def _has_empty(block):
564
+ if isinstance(block, list):
565
+ return len(block) == 0 or any(_has_empty(b) for b in block)
566
+ return not block.translated
567
+
568
+ while isinstance(out_block, list) and len(out_block) == 1:
569
+ out_block = out_block[0]
570
+
571
+ if _has_empty(out_block):
572
+ if fail_path is not None:
573
+ self._save_to_file(out_block, fail_path)
493
574
  continue
494
575
 
495
576
  if collection_name is not None:
@@ -507,31 +588,22 @@ class Converter:
507
588
 
508
589
  log.info(f"Total cost: ${total_cost:,.2f}")
509
590
 
510
- def translate_file(
511
- self, file: Path, failure_path: Path | None = None
512
- ) -> TranslatedCodeBlock:
513
- """Translate a single file.
514
-
515
- Arguments:
516
- file: Input path to file
517
- failure_path: path to directory to store failure summaries`
518
-
519
- Returns:
520
- A `TranslatedCodeBlock` object. This block does not have a path set, and its
521
- code is not guaranteed to be consolidated. To amend this, run
522
- `Combiner.combine_children` on the block.
523
- """
591
+ def translate_block(
592
+ self,
593
+ input_block: CodeBlock | list[CodeBlock],
594
+ name: str,
595
+ failure_path: Path | None = None,
596
+ ):
524
597
  self._load_parameters()
525
- filename = file.name
526
-
527
- input_block = self._split_file(file)
598
+ if isinstance(input_block, list):
599
+ return [self.translate_block(b, name, failure_path) for b in input_block]
528
600
  t0 = time.time()
529
601
  output_block = self._iterative_translate(input_block, failure_path)
530
602
  output_block.processing_time = time.time() - t0
531
603
  if output_block.translated:
532
604
  completeness = output_block.translation_completeness
533
605
  log.info(
534
- f"[{filename}] Translation complete\n"
606
+ f"[{name}] Translation complete\n"
535
607
  f" {completeness:.2%} of input successfully translated\n"
536
608
  f" Total cost: ${output_block.total_cost:,.2f}\n"
537
609
  f" Output CodeBlock Structure:\n{input_block.tree_str()}\n"
@@ -539,11 +611,52 @@ class Converter:
539
611
 
540
612
  else:
541
613
  log.error(
542
- f"[{filename}] Translation failed\n"
614
+ f"[{name}] Translation failed\n"
543
615
  f" Total cost: ${output_block.total_cost:,.2f}\n"
544
616
  )
545
617
  return output_block
546
618
 
619
+ def translate_file(
620
+ self,
621
+ file: Path,
622
+ failure_path: Path | None = None,
623
+ ) -> TranslatedCodeBlock:
624
+ """Translate a single file.
625
+
626
+ Arguments:
627
+ file: Input path to file
628
+ failure_path: path to directory to store failure summaries`
629
+
630
+ Returns:
631
+ A `TranslatedCodeBlock` object. This block does not have a path set, and its
632
+ code is not guaranteed to be consolidated. To amend this, run
633
+ `Combiner.combine_children` on the block.
634
+ """
635
+ filename = file.name
636
+ input_block = self._split_file(file)
637
+ return self.translate_block(input_block, filename, failure_path)
638
+
639
+ def translate_janus_file(self, file: Path, failure_path: Path | None = None):
640
+ filename = file.name
641
+ with open(file, "r") as f:
642
+ file_obj = json.load(f)
643
+ return self.translate_janus_obj(file_obj, filename, failure_path)
644
+
645
+ def translate_janus_obj(self, obj: Any, name: str, failure_path: Path | None = None):
646
+ block = self._janus_object_to_codeblock(obj, name)
647
+ return self.translate_block(block)
648
+
649
+ def translate_text(self, text: str, name: str, failure_path: Path | None = None):
650
+ """
651
+ Translates given text
652
+ Arguments:
653
+ text: text to translate
654
+ name: the name of the text (filename if from a file)
655
+ failure_path: path to write failure file if translation is not successful
656
+ """
657
+ input_block = self._split_text(text, name)
658
+ return self.translate_block(input_block, name, failure_path)
659
+
547
660
  def _iterative_translate(
548
661
  self, root: CodeBlock, failure_path: Path | None = None
549
662
  ) -> TranslatedCodeBlock:
@@ -607,9 +720,8 @@ class Converter:
607
720
  )
608
721
  raise e
609
722
  finally:
610
- log.debug(
611
- f"Resulting Block: {json.dumps(self._get_output_obj(translated_root))}"
612
- )
723
+ out_obj = self._get_output_obj(translated_root, self._combine_output)
724
+ log.debug(f"Resulting Block:" f"{json.dumps(out_obj)}")
613
725
  if not translated_root.translated:
614
726
  if failure_path is not None:
615
727
  self._save_to_file(translated_root, failure_path)
@@ -666,6 +778,16 @@ class Converter:
666
778
 
667
779
  log.debug(f"[{block.name}] Output code:\n{block.text}")
668
780
 
781
+ def _split_text(self, text: str, name: str) -> CodeBlock:
782
+ log.info(f"[{name}] Splitting text")
783
+ root = self._splitter.split_string(text, name)
784
+ log.info(
785
+ f"[{name}] Text split into {root.n_descendents:,} blocks,"
786
+ f"tree of height {root.height}"
787
+ )
788
+ log.info(f"[{name}] Input CodeBlock Structure:\n{root.tree_str()}")
789
+ return root
790
+
669
791
  def _split_file(self, file: Path) -> CodeBlock:
670
792
  filename = file.name
671
793
  log.info(f"[{filename}] Splitting file")
@@ -680,19 +802,51 @@ class Converter:
680
802
  def _run_chain(self, block: TranslatedCodeBlock) -> str:
681
803
  return self.chain.invoke(block.original)
682
804
 
805
+ def _combine_metadata(self, metadatas: list[dict]):
806
+ return dict(
807
+ cost=sum(m["cost"] for m in metadatas),
808
+ processing_time=sum(m["processing_time"] for m in metadatas),
809
+ num_requests=sum(m["num_requests"] for m in metadatas),
810
+ input_tokens=sum(m["input_tokens"] for m in metadatas),
811
+ output_tokens=sum(m["output_tokens"] for m in metadatas),
812
+ converter_name=self.__class__.__name__,
813
+ )
814
+
815
+ def _combine_inputs(self, inputs: list[str]):
816
+ s = ""
817
+ for i in inputs:
818
+ s += i
819
+ return s
820
+
683
821
  def _get_output_obj(
684
- self, block: TranslatedCodeBlock
822
+ self, block: TranslatedCodeBlock | list, combine_children: bool = True
685
823
  ) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
824
+ if isinstance(block, list):
825
+ # TODO: run on all items in list
826
+ outputs = [self._get_output_obj(b, combine_children) for b in block]
827
+ metadata = self._combine_metadata([o["metadata"] for o in outputs])
828
+ input_agg = self._combine_inputs(o["input"] for o in outputs)
829
+ return dict(
830
+ input=input_agg,
831
+ metadata=metadata,
832
+ outputs=outputs,
833
+ )
834
+ if not combine_children and len(block.children) > 0:
835
+ outputs = self._get_output_obj_children(block)
836
+ metadata = self._combine_metadata([o["metadata"] for o in outputs])
837
+ input_agg = self._combine_inputs(o["input"] for o in outputs)
838
+ return dict(
839
+ input=input_agg,
840
+ metadata=metadata,
841
+ outputs=outputs,
842
+ )
686
843
  output_obj: str | dict[str, str]
687
844
  if not block.translation_completed:
688
845
  # translation wasn't completed, so combined parsing will likely fail
689
- output_obj = block.complete_text
846
+ output_obj = [block.complete_text]
690
847
  else:
691
848
  output_str = self._parser.parse_combined_output(block.complete_text)
692
- try:
693
- output_obj = json.loads(output_str)
694
- except json.JSONDecodeError:
695
- output_obj = output_str
849
+ output_obj = [output_str]
696
850
 
697
851
  return dict(
698
852
  input=block.original.text or "",
@@ -702,16 +856,117 @@ class Converter:
702
856
  num_requests=block.total_num_requests,
703
857
  input_tokens=block.total_request_input_tokens,
704
858
  output_tokens=block.total_request_output_tokens,
859
+ converter_name=self.__class__.__name__,
705
860
  ),
706
- output=output_obj,
861
+ outputs=output_obj,
707
862
  )
708
863
 
864
+ def _get_output_obj_children(self, block: TranslatedCodeBlock):
865
+ if len(block.children) > 0:
866
+ res = []
867
+ for c in block.children:
868
+ res += self._get_output_obj_children(c)
869
+ return res
870
+ else:
871
+ return [self._get_output_obj(block, combine_children=True)]
872
+
709
873
  def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
710
874
  """Save a file to disk.
711
875
 
712
876
  Arguments:
713
877
  block: The `TranslatedCodeBlock` to save to a file.
714
878
  """
715
- obj = self._get_output_obj(block)
879
+ obj = self._get_output_obj(block, combine_children=self._combine_output)
716
880
  out_path.parent.mkdir(parents=True, exist_ok=True)
717
881
  out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
882
+
883
+ def _janus_object_to_codeblock(self, janus_obj: dict, name: str):
884
+ results = []
885
+ for o in janus_obj["outputs"]:
886
+ if isinstance(o, str):
887
+ code_block = self._split_text(o, name)
888
+ meta_data = janus_obj["metadata"]
889
+ code_block.initial_cost = meta_data["cost"]
890
+ code_block.initial_input_tokens = meta_data["input_tokens"]
891
+ code_block.initial_output_tokens = meta_data["output_tokens"]
892
+ code_block.initial_num_requests = meta_data["num_requests"]
893
+ code_block.initial_processing_time = meta_data["processing_time"]
894
+ code_block.previous_generations = janus_obj.get(
895
+ "intermediate_outputs", []
896
+ ) + [janus_obj]
897
+ results.append(code_block)
898
+ else:
899
+ results.append(self._janus_object_to_codeblock(o))
900
+ while isinstance(results, list) and len(results) == 1:
901
+ results = results[0]
902
+ return results
903
+
904
+ def __or__(self, other: "Converter"):
905
+ from janus.converter.chain import ConverterChain
906
+
907
+ return ConverterChain(self, other)
908
+
909
+ @property
910
+ def source_language(self):
911
+ return self._source_language
912
+
913
+ @property
914
+ def target_language(self):
915
+ return self._target_language
916
+
917
+ @property
918
+ def target_version(self):
919
+ return self._target_version
920
+
921
+ def set_target_language(
922
+ self, target_language: str, target_version: str | None
923
+ ) -> None:
924
+ """Validate and set the target language.
925
+
926
+ The affected objects will not be updated until translate() is called.
927
+
928
+ Arguments:
929
+ target_language: The target programming language.
930
+ target_version: The target version of the target programming language.
931
+ """
932
+ target_language = target_language.lower()
933
+ if target_language not in LANGUAGES:
934
+ raise ValueError(
935
+ f"Invalid target language: {target_language}. "
936
+ "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
937
+ )
938
+ self._target_language = target_language
939
+ self._target_version = target_version
940
+ # Taking the first suffix as the default for output files
941
+ self._target_suffix = f".{LANGUAGES[target_language]['suffixes'][0]}"
942
+
943
+ @classmethod
944
+ def eval_obj(cls, target, metric_func, *args, **kwargs):
945
+ if "reference" in kwargs:
946
+ return cls.eval_obj_reference(target, metric_func, *args, **kwargs)
947
+ else:
948
+ return cls.eval_obj_noreference(target, metric_func, *args, **kwargs)
949
+
950
+ @classmethod
951
+ def eval_obj_noreference(cls, target, metric_func, *args, **kwargs):
952
+ results = []
953
+ for o in target["outputs"]:
954
+ if isinstance(o, dict):
955
+ results += cls.eval_obj_noreference(o, metric_func, *args, **kwargs)
956
+ else:
957
+ results.append(metric_func(o, *args, **kwargs))
958
+ return results
959
+
960
+ @classmethod
961
+ def eval_obj_reference(cls, target, metric_func, reference, *args, **kwargs):
962
+ results = []
963
+ for o, r in zip(target["outputs"], reference["outputs"]):
964
+ if isinstance(o, dict):
965
+ if not isinstance(r, dict):
966
+ raise ValueError("Error: format of reference doesn't match target")
967
+ results += cls.eval_obj_reference(o, metric_func, r, *args, **kwargs)
968
+ else:
969
+ if isinstance(r, dict):
970
+ raise ValueError("Error: format of reference doesn't match target")
971
+ results.append(metric_func(o, r, *args, **kwargs))
972
+ return results
@@ -14,6 +14,7 @@ class DiagramGenerator(Documenter):
14
14
  self,
15
15
  diagram_type="Activity",
16
16
  add_documentation=False,
17
+ extract_variables=False,
17
18
  **kwargs,
18
19
  ) -> None:
19
20
  """Initialize the DiagramGenerator class
@@ -28,24 +29,25 @@ class DiagramGenerator(Documenter):
28
29
  self._documenter = Documenter(**kwargs)
29
30
 
30
31
  super().__init__(**kwargs)
31
-
32
- self.set_prompt("diagram_with_documentation" if add_documentation else "diagram")
32
+ prompts = []
33
+ if extract_variables:
34
+ prompts.append("extract_variables")
35
+ prompts += ["diagram_with_documentation" if add_documentation else "diagram"]
36
+ self.set_prompts(prompts)
33
37
  self._parser = UMLSyntaxParser(language="plantuml")
34
38
 
35
39
  self._load_parameters()
36
40
 
37
- def _load_prompt(self):
38
- super()._load_prompt()
39
- self._prompt = self._prompt.partial(DIAGRAM_TYPE=self._diagram_type)
40
-
41
41
  def _input_runnable(self) -> Runnable:
42
42
  if self._add_documentation:
43
43
  return RunnableParallel(
44
44
  SOURCE_CODE=self._parser.parse_input,
45
45
  DOCUMENTATION=self._documenter.chain,
46
46
  context=self._retriever,
47
+ DIAGRAM_TYPE=lambda x: self._diagram_type,
47
48
  )
48
49
  return RunnableParallel(
49
50
  SOURCE_CODE=self._parser.parse_input,
50
51
  context=self._retriever,
52
+ DIAGRAM_TYPE=lambda x: self._diagram_type,
51
53
  )