janus-llm 3.5.3__py3-none-any.whl → 4.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
janus/__init__.py CHANGED
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
5
5
  from janus.converter.translate import Translator
6
6
  from janus.metrics import * # noqa: F403
7
7
 
8
- __version__ = "3.5.3"
8
+ __version__ = "4.1.0"
9
9
 
10
10
  # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
11
11
  warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
janus/cli.py CHANGED
@@ -39,10 +39,12 @@ from janus.llm.models_info import (
39
39
  MODEL_TYPE_CONSTRUCTORS,
40
40
  MODEL_TYPES,
41
41
  TOKEN_LIMITS,
42
+ azure_models,
42
43
  bedrock_models,
43
44
  openai_models,
44
45
  )
45
46
  from janus.metrics.cli import evaluate
47
+ from janus.refiners.refiner import REFINERS
46
48
  from janus.utils.enums import LANGUAGES
47
49
  from janus.utils.logger import create_logger
48
50
 
@@ -242,6 +244,24 @@ def translate(
242
244
  click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
243
245
  ),
244
246
  ] = "file",
247
+ refiner_type: Annotated[
248
+ str,
249
+ typer.Option(
250
+ "-r",
251
+ "--refiner",
252
+ help="Name of custom refiner to use",
253
+ click_type=click.Choice(list(REFINERS.keys())),
254
+ ),
255
+ ] = "none",
256
+ retriever_type: Annotated[
257
+ str,
258
+ typer.Option(
259
+ "-R",
260
+ "--retriever",
261
+ help="Name of custom retriever to use",
262
+ click_type=click.Choice(["active_usings"]),
263
+ ),
264
+ ] = None,
245
265
  max_tokens: Annotated[
246
266
  int,
247
267
  typer.Option(
@@ -251,13 +271,6 @@ def translate(
251
271
  "If unspecificed, model's default max will be used.",
252
272
  ),
253
273
  ] = None,
254
- skip_refiner: Annotated[
255
- bool,
256
- typer.Option(
257
- "--skip-refiner",
258
- help="Whether to skip the refiner for generating output",
259
- ),
260
- ] = True,
261
274
  ):
262
275
  try:
263
276
  target_language, target_version = target_lang.split("-")
@@ -283,8 +296,8 @@ def translate(
283
296
  db_path=db_loc,
284
297
  db_config=collections_config,
285
298
  splitter_type=splitter_type,
286
- skip_context=skip_context,
287
- skip_refiner=skip_refiner,
299
+ refiner_type=refiner_type,
300
+ retriever_type=retriever_type,
288
301
  )
289
302
  translator.translate(input_dir, output_dir, overwrite, collection)
290
303
 
@@ -342,14 +355,6 @@ def document(
342
355
  help="Whether to overwrite existing files in the output directory",
343
356
  ),
344
357
  ] = False,
345
- skip_context: Annotated[
346
- bool,
347
- typer.Option(
348
- "--skip-context",
349
- help="Prompts will include any context information associated with source"
350
- " code blocks, unless this option is specified",
351
- ),
352
- ] = False,
353
358
  doc_mode: Annotated[
354
359
  str,
355
360
  typer.Option(
@@ -397,6 +402,24 @@ def document(
397
402
  click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
398
403
  ),
399
404
  ] = "file",
405
+ refiner_type: Annotated[
406
+ str,
407
+ typer.Option(
408
+ "-r",
409
+ "--refiner",
410
+ help="Name of custom refiner to use",
411
+ click_type=click.Choice(list(REFINERS.keys())),
412
+ ),
413
+ ] = "none",
414
+ retriever_type: Annotated[
415
+ str,
416
+ typer.Option(
417
+ "-R",
418
+ "--retriever",
419
+ help="Name of custom retriever to use",
420
+ click_type=click.Choice(["active_usings"]),
421
+ ),
422
+ ] = None,
400
423
  max_tokens: Annotated[
401
424
  int,
402
425
  typer.Option(
@@ -406,13 +429,6 @@ def document(
406
429
  "If unspecificed, model's default max will be used.",
407
430
  ),
408
431
  ] = None,
409
- skip_refiner: Annotated[
410
- bool,
411
- typer.Option(
412
- "--skip-refiner",
413
- help="Whether to skip the refiner for generating output",
414
- ),
415
- ] = True,
416
432
  ):
417
433
  model_arguments = dict(temperature=temperature)
418
434
  collections_config = get_collections_config()
@@ -425,8 +441,8 @@ def document(
425
441
  db_path=db_loc,
426
442
  db_config=collections_config,
427
443
  splitter_type=splitter_type,
428
- skip_refiner=skip_refiner,
429
- skip_context=skip_context,
444
+ refiner_type=refiner_type,
445
+ retriever_type=retriever_type,
430
446
  )
431
447
  if doc_mode == "madlibs":
432
448
  documenter = MadLibsDocumenter(
@@ -615,14 +631,6 @@ def diagram(
615
631
  help="Whether to overwrite existing files in the output directory",
616
632
  ),
617
633
  ] = False,
618
- skip_context: Annotated[
619
- bool,
620
- typer.Option(
621
- "--skip-context",
622
- help="Prompts will include any context information associated with source"
623
- " code blocks, unless this option is specified",
624
- ),
625
- ] = False,
626
634
  temperature: Annotated[
627
635
  float,
628
636
  typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
@@ -659,13 +667,24 @@ def diagram(
659
667
  click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
660
668
  ),
661
669
  ] = "file",
662
- skip_refiner: Annotated[
663
- bool,
670
+ refiner_type: Annotated[
671
+ str,
664
672
  typer.Option(
665
- "--skip-refiner",
666
- help="Whether to skip the refiner for generating output",
673
+ "-r",
674
+ "--refiner",
675
+ help="Name of custom refiner to use",
676
+ click_type=click.Choice(list(REFINERS.keys())),
667
677
  ),
668
- ] = True,
678
+ ] = "none",
679
+ retriever_type: Annotated[
680
+ str,
681
+ typer.Option(
682
+ "-R",
683
+ "--retriever",
684
+ help="Name of custom retriever to use",
685
+ click_type=click.Choice(["active_usings"]),
686
+ ),
687
+ ] = None,
669
688
  ):
670
689
  model_arguments = dict(temperature=temperature)
671
690
  collections_config = get_collections_config()
@@ -676,11 +695,11 @@ def diagram(
676
695
  max_prompts=max_prompts,
677
696
  db_path=db_loc,
678
697
  db_config=collections_config,
698
+ splitter_type=splitter_type,
699
+ refiner_type=refiner_type,
700
+ retriever_type=retriever_type,
679
701
  diagram_type=diagram_type,
680
702
  add_documentation=add_documentation,
681
- splitter_type=splitter_type,
682
- skip_refiner=skip_refiner,
683
- skip_context=skip_context,
684
703
  )
685
704
  diagram_generator.translate(input_dir, output_dir, overwrite, collection)
686
705
 
@@ -934,7 +953,7 @@ def llm_add(
934
953
  help="The type of the model",
935
954
  click_type=click.Choice(sorted(list(MODEL_TYPE_CONSTRUCTORS.keys()))),
936
955
  ),
937
- ] = "OpenAI",
956
+ ] = "Azure",
938
957
  ):
939
958
  if not MODEL_CONFIG_DIR.exists():
940
959
  MODEL_CONFIG_DIR.mkdir(parents=True)
@@ -978,6 +997,7 @@ def llm_add(
978
997
  "model_cost": {"input": in_cost, "output": out_cost},
979
998
  }
980
999
  elif model_type == "OpenAI":
1000
+ print("DEPRECATED: Use 'Azure' instead. CTRL+C to exit.")
981
1001
  model_id = typer.prompt(
982
1002
  "Enter the model ID (list model IDs with `janus llm ls -a`)",
983
1003
  default="gpt-4o",
@@ -999,6 +1019,28 @@ def llm_add(
999
1019
  "token_limit": max_tokens,
1000
1020
  "model_cost": model_cost,
1001
1021
  }
1022
+ elif model_type == "Azure":
1023
+ model_id = typer.prompt(
1024
+ "Enter the model ID (list model IDs with `janus llm ls -a`)",
1025
+ default="gpt-4o",
1026
+ type=click.Choice(azure_models),
1027
+ show_choices=False,
1028
+ )
1029
+ params = dict(
1030
+ # Azure uses the "azure_deployment" key for what we're calling "long_model_id"
1031
+ azure_deployment=MODEL_ID_TO_LONG_ID[model_id],
1032
+ temperature=0.7,
1033
+ n=1,
1034
+ )
1035
+ max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
1036
+ model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
1037
+ cfg = {
1038
+ "model_type": model_type,
1039
+ "model_id": model_id,
1040
+ "model_args": params,
1041
+ "token_limit": max_tokens,
1042
+ "model_cost": model_cost,
1043
+ }
1002
1044
  elif model_type == "BedrockChat" or model_type == "Bedrock":
1003
1045
  model_id = typer.prompt(
1004
1046
  "Enter the model ID (list model IDs with `janus llm ls -a`)",
@@ -1173,13 +1215,14 @@ def render(
1173
1215
  for input_file in input_dir.rglob("*.json"):
1174
1216
  with open(input_file, "r") as f:
1175
1217
  data = json.load(f)
1176
- input_tail = input_file.relative_to(input_dir)
1177
- output_file = output_dir / input_tail
1178
- output_file = output_file.with_suffix(".txt")
1218
+
1219
+ output_file = output_dir / input_file.relative_to(input_dir).with_suffix(".txt")
1179
1220
  if not output_file.parent.exists():
1180
1221
  output_file.parent.mkdir()
1181
- with open(output_file, "w") as f:
1182
- f.write(data["output"])
1222
+
1223
+ text = data["output"].replace("\\n", "\n").strip()
1224
+ output_file.write_text(text)
1225
+
1183
1226
  jar_path = homedir / ".janus/lib/plantuml.jar"
1184
1227
  subprocess.run(["java", "-jar", jar_path, output_file]) # nosec
1185
1228
  output_file.unlink()
@@ -90,14 +90,14 @@ class TestDiagramGenerator(unittest.TestCase):
90
90
  def setUp(self):
91
91
  """Set up the tests."""
92
92
  self.diagram_generator = DiagramGenerator(
93
- model="gpt-4o",
93
+ model="gpt-4o-mini",
94
94
  source_language="fortran",
95
95
  diagram_type="Activity",
96
96
  )
97
97
 
98
98
  def test_init(self):
99
99
  """Test __init__ method."""
100
- self.assertEqual(self.diagram_generator._model_name, "gpt-4o")
100
+ self.assertEqual(self.diagram_generator._model_name, "gpt-4o-mini")
101
101
  self.assertEqual(self.diagram_generator._source_language, "fortran")
102
102
  self.assertEqual(self.diagram_generator._diagram_type, "Activity")
103
103
 
@@ -2,13 +2,11 @@ import functools
2
2
  import json
3
3
  import time
4
4
  from pathlib import Path
5
- from typing import Any, List, Optional, Tuple
5
+ from typing import Any
6
6
 
7
- from langchain.output_parsers import RetryWithErrorOutputParser
8
7
  from langchain_core.exceptions import OutputParserException
9
- from langchain_core.language_models import BaseLanguageModel
10
- from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
11
- from langchain_core.runnables import RunnableLambda, RunnableParallel
8
+ from langchain_core.prompts import ChatPromptTemplate
9
+ from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
12
10
  from openai import BadRequestError, RateLimitError
13
11
  from pydantic import ValidationError
14
12
 
@@ -22,12 +20,18 @@ from janus.language.splitter import (
22
20
  Splitter,
23
21
  TokenLimitError,
24
22
  )
25
- from janus.llm import load_model
26
23
  from janus.llm.model_callbacks import get_model_callback
27
- from janus.llm.models_info import MODEL_PROMPT_ENGINES
24
+ from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel, load_model
28
25
  from janus.parsers.parser import GenericParser, JanusParser
29
- from janus.parsers.refiner_parser import RefinerParser
30
- from janus.refiners.refiner import BasicRefiner, Refiner
26
+ from janus.refiners.refiner import (
27
+ FixParserExceptions,
28
+ HallucinationRefiner,
29
+ JanusRefiner,
30
+ ReflectionRefiner,
31
+ )
32
+
33
+ # from janus.refiners.refiner import BasicRefiner, Refiner
34
+ from janus.retrievers.retriever import ActiveUsingsRetriever, JanusRetriever
31
35
  from janus.utils.enums import LANGUAGES
32
36
  from janus.utils.logger import create_logger
33
37
 
@@ -74,9 +78,8 @@ class Converter:
74
78
  protected_node_types: tuple[str, ...] = (),
75
79
  prune_node_types: tuple[str, ...] = (),
76
80
  splitter_type: str = "file",
77
- refiner_type: str = "basic",
78
- skip_refiner: bool = True,
79
- skip_context: bool = False,
81
+ refiner_type: str | None = None,
82
+ retriever_type: str | None = None,
80
83
  ) -> None:
81
84
  """Initialize a Converter instance.
82
85
 
@@ -96,9 +99,13 @@ class Converter:
96
99
  prune_node_types: A set of node types which should be pruned.
97
100
  splitter_type: The type of splitter to use. Valid values are `"file"`,
98
101
  `"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
99
- refiner_type: The type of refiner to use. Valid values are `"basic"`.
100
- skip_refiner: Whether to skip the refiner.
101
- skip_context: Whether to skip adding context to the prompt.
102
+ refiner_type: The type of refiner to use. Valid values:
103
+ - "parser"
104
+ - "reflection"
105
+ - None
106
+ retriever_type: The type of retriever to use. Valid values:
107
+ - "active_usings"
108
+ - None
102
109
  """
103
110
  self._changed_attrs: set = set()
104
111
 
@@ -107,7 +114,6 @@ class Converter:
107
114
  self.override_token_limit: bool = max_tokens is not None
108
115
 
109
116
  self._model_name: str
110
- self._model_id: str
111
117
  self._custom_model_arguments: dict[str, Any]
112
118
 
113
119
  self._source_language: str
@@ -120,24 +126,26 @@ class Converter:
120
126
  self._prune_node_types: tuple[str, ...] = ()
121
127
  self._max_tokens: int | None = max_tokens
122
128
  self._prompt_template_name: str
123
- self._splitter_type: str
124
129
  self._db_path: str | None
125
130
  self._db_config: dict[str, Any] | None
126
131
 
127
- self._splitter: Splitter
128
- self._llm: BaseLanguageModel
132
+ self._llm: JanusModel
129
133
  self._prompt: ChatPromptTemplate
130
134
 
131
135
  self._parser: JanusParser = GenericParser()
132
136
  self._combiner: Combiner = Combiner()
133
137
 
134
- self._refiner_type: str
135
- self._refiner: Refiner
138
+ self._splitter_type: str
139
+ self._refiner_type: str | None
140
+ self._retriever_type: str | None
136
141
 
137
- self.skip_refiner = skip_refiner
142
+ self._splitter: Splitter
143
+ self._refiner: JanusRefiner
144
+ self._retriever: JanusRetriever
138
145
 
139
146
  self.set_splitter(splitter_type=splitter_type)
140
147
  self.set_refiner(refiner_type=refiner_type)
148
+ self.set_retriever(retriever_type=retriever_type)
141
149
  self.set_model(model_name=model, **model_arguments)
142
150
  self.set_prompt(prompt_template=prompt_template)
143
151
  self.set_source_language(source_language)
@@ -146,8 +154,6 @@ class Converter:
146
154
  self.set_db_path(db_path=db_path)
147
155
  self.set_db_config(db_config=db_config)
148
156
 
149
- self.skip_context = skip_context
150
-
151
157
  # Child class must call this. Should we enforce somehow?
152
158
  # self._load_parameters()
153
159
 
@@ -163,9 +169,11 @@ class Converter:
163
169
  def _load_parameters(self) -> None:
164
170
  self._load_model()
165
171
  self._load_prompt()
172
+ self._load_retriever()
173
+ self._load_refiner()
166
174
  self._load_splitter()
167
175
  self._load_vectorizer()
168
- self._load_refiner()
176
+ self._load_chain()
169
177
  self._changed_attrs.clear()
170
178
 
171
179
  def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
@@ -184,8 +192,6 @@ class Converter:
184
192
  def set_prompt(self, prompt_template: str) -> None:
185
193
  """Validate and set the prompt template name.
186
194
 
187
- The affected objects will not be updated until translate() is called.
188
-
189
195
  Arguments:
190
196
  prompt_template: name of prompt template directory
191
197
  (see janus/prompts/templates) or path to a directory.
@@ -195,29 +201,34 @@ class Converter:
195
201
  def set_splitter(self, splitter_type: str) -> None:
196
202
  """Validate and set the splitter type.
197
203
 
198
- The affected objects will not be updated until translate() is called.
199
-
200
204
  Arguments:
201
205
  splitter_type: name of the splitter to use
202
206
  (must be a key of CUSTOM_SPLITTERS).
203
207
  """
204
- self._splitter_type = splitter_type
208
+ if splitter_type not in CUSTOM_SPLITTERS:
209
+ raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
205
210
 
206
- def set_refiner(self, refiner_type: str) -> None:
207
- """Validate and set the refiner name
211
+ self._splitter_type = splitter_type
208
212
 
209
- The affected objects will not be updated until translate is called
213
+ def set_refiner(self, refiner_type: str | None) -> None:
214
+ """Validate and set the refiner type
210
215
 
211
216
  Arguments:
212
- refiner_type: the name of the refiner to use
217
+ refiner_type: the type of refiner to use
213
218
  """
214
219
  self._refiner_type = refiner_type
215
220
 
221
+ def set_retriever(self, retriever_type: str | None) -> None:
222
+ """Validate and set the retriever type
223
+
224
+ Arguments:
225
+ retriever_type: the type of retriever to use
226
+ """
227
+ self._retriever_type = retriever_type
228
+
216
229
  def set_source_language(self, source_language: str) -> None:
217
230
  """Validate and set the source language.
218
231
 
219
- The affected objects will not be updated until _load_parameters() is called.
220
-
221
232
  Arguments:
222
233
  source_language: The source programming language.
223
234
  """
@@ -287,20 +298,6 @@ class Converter:
287
298
 
288
299
  self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
289
300
 
290
- @run_if_changed("_refiner_type", "_model_name")
291
- def _load_refiner(self) -> None:
292
- """Load the refiner according to this instance's attributes.
293
-
294
- If the relevant fields have not been changed since the last time this method was
295
- called, nothing happens.
296
- """
297
- if self._refiner_type == "basic":
298
- self._refiner = BasicRefiner(
299
- "basic_refinement", self._model_id, self._source_language
300
- )
301
- else:
302
- raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
303
-
304
301
  @run_if_changed("_model_name", "_custom_model_arguments")
305
302
  def _load_model(self) -> None:
306
303
  """Load the model according to this instance's attributes.
@@ -314,9 +311,9 @@ class Converter:
314
311
  # model_arguments.update(self._custom_model_arguments)
315
312
 
316
313
  # Load the model
317
- self._llm, self._model_id, token_limit, self.model_cost = load_model(
318
- self._model_name
319
- )
314
+ self._llm = load_model(self._model_name)
315
+ token_limit = self._llm.token_limit
316
+
320
317
  # Set the max_tokens to less than half the model's limit to allow for enough
321
318
  # tokens at output
322
319
  # Only modify max_tokens if it is not specified by user
@@ -335,7 +332,7 @@ class Converter:
335
332
  If the relevant fields have not been changed since the last time this
336
333
  method was called, nothing happens.
337
334
  """
338
- prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
335
+ prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
339
336
  source_language=self._source_language,
340
337
  prompt_template=self._prompt_template_name,
341
338
  )
@@ -354,6 +351,59 @@ class Converter:
354
351
  self._db_path, self._db_config
355
352
  )
356
353
 
354
+ @run_if_changed("_retriever_type")
355
+ def _load_retriever(self):
356
+ if self._retriever_type == "active_usings":
357
+ self._retriever = ActiveUsingsRetriever()
358
+ else:
359
+ self._retriever = JanusRetriever()
360
+
361
+ @run_if_changed("_refiner_type", "_model_name", "max_prompts", "_parser", "_llm")
362
+ def _load_refiner(self) -> None:
363
+ """Load the refiner according to this instance's attributes.
364
+
365
+ If the relevant fields have not been changed since the last time this method was
366
+ called, nothing happens.
367
+ """
368
+ if self._refiner_type == "parser":
369
+ self._refiner = FixParserExceptions(
370
+ llm=self._llm,
371
+ parser=self._parser,
372
+ max_retries=self.max_prompts,
373
+ )
374
+ elif self._refiner_type == "reflection":
375
+ self._refiner = ReflectionRefiner(
376
+ llm=self._llm,
377
+ parser=self._parser,
378
+ max_retries=self.max_prompts,
379
+ )
380
+ elif self._refiner_type == "hallucination":
381
+ self._refiner = HallucinationRefiner(
382
+ llm=self._llm,
383
+ parser=self._parser,
384
+ max_retries=self.max_prompts,
385
+ )
386
+ else:
387
+ self._refiner = JanusRefiner(parser=self._parser)
388
+
389
+ @run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner")
390
+ def _load_chain(self):
391
+ self.chain = (
392
+ self._input_runnable()
393
+ | self._prompt
394
+ | RunnableParallel(
395
+ completion=self._llm,
396
+ prompt_value=RunnablePassthrough(),
397
+ )
398
+ | self._refiner.parse_runnable
399
+ )
400
+
401
+ def _input_runnable(self) -> Runnable:
402
+ return RunnableParallel(
403
+ SOURCE_CODE=self._parser.parse_input,
404
+ context=self._retriever,
405
+ )
406
+
357
407
  def translate(
358
408
  self,
359
409
  input_directory: str | Path,
@@ -598,110 +648,29 @@ class Converter:
598
648
  return root
599
649
 
600
650
  def _run_chain(self, block: TranslatedCodeBlock) -> str:
601
- """Run the model with three nested error fixing schemes.
602
- First, try to fix simple formatting errors by giving the model just
603
- the output and the parsing error. After a number of attempts, try
604
- giving the model the output, the parsing error, and the original
605
- input. Again check/retry this output to solve for formatting errors.
606
- If we still haven't succeeded after several attempts, the model may
607
- be getting thrown off by a bad initial output; start from scratch
608
- and try again.
609
-
610
- The number of tries for each layer of this scheme is roughly equal
611
- to the cube root of self.max_retries, so the total calls to the
612
- LLM will be roughly as expected (up to sqrt(self.max_retries) over)
613
- """
614
- input = self._parser.parse_input(block.original)
615
-
616
- # Retries with just the output and the error
617
- n1 = round(self.max_prompts ** (1 / 2))
618
-
619
- # Retries with the input, output, and error
620
- n2 = round(self.max_prompts // n1)
621
-
622
- if not self.skip_context:
623
- self._make_prompt_additions(block)
624
- if not self.skip_refiner: # Make replacements in the prompt
625
- refine_output = RefinerParser(
626
- parser=self._parser,
627
- initial_prompt=self._prompt.format(**{"SOURCE_CODE": input}),
628
- refiner=self._refiner,
629
- max_retries=n1,
630
- llm=self._llm,
631
- )
632
- else:
633
- refine_output = RetryWithErrorOutputParser.from_llm(
634
- llm=self._llm,
635
- parser=self._parser,
636
- max_retries=n1,
637
- )
638
-
639
- completion_chain = self._prompt | self._llm
640
- chain = RunnableParallel(
641
- completion=completion_chain, prompt_value=self._prompt
642
- ) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
643
- for _ in range(n2):
644
- try:
645
- return chain.invoke({"SOURCE_CODE": input})
646
- except OutputParserException:
647
- pass
648
-
649
- raise OutputParserException(f"Failed to parse after {n1*n2} retries")
651
+ return self.chain.invoke(block.original)
650
652
 
651
653
  def _get_output_obj(
652
654
  self, block: TranslatedCodeBlock
653
- ) -> dict[str, int | float | str | dict[str, str]]:
655
+ ) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
654
656
  output_str = self._parser.parse_combined_output(block.complete_text)
655
657
 
656
- output: str | dict[str, str]
658
+ output_obj: str | dict[str, str]
657
659
  try:
658
- output = json.loads(output_str)
660
+ output_obj = json.loads(output_str)
659
661
  except json.JSONDecodeError:
660
- output = output_str
662
+ output_obj = output_str
661
663
 
662
664
  return dict(
663
- input=block.original.text,
665
+ input=block.original.text or "",
664
666
  metadata=dict(
665
667
  retries=block.total_retries,
666
668
  cost=block.total_cost,
667
669
  processing_time=block.processing_time,
668
670
  ),
669
- output=output,
670
- )
671
-
672
- @staticmethod
673
- def _get_prompt_additions(block) -> Optional[List[Tuple[str, str]]]:
674
- """Get a list of strings to append to the prompt.
675
-
676
- Arguments:
677
- block: The `TranslatedCodeBlock` to save to a file.
678
- """
679
- return [(key, item) for key, item in block.context_tags.items()]
680
-
681
- def _make_prompt_additions(self, block: CodeBlock):
682
- # Prepare the additional context to prepend
683
- additional_context = "".join(
684
- [
685
- f"{context_tag}: {context}\n"
686
- for context_tag, context in self._get_prompt_additions(block)
687
- ]
671
+ output=output_obj,
688
672
  )
689
673
 
690
- if not hasattr(self._prompt, "messages"):
691
- log.debug("Skipping additions to prompt, no messages found on prompt object!")
692
- return
693
-
694
- # Iterate through existing messages to find and update the system message
695
- for i, message in enumerate(self._prompt.messages):
696
- if isinstance(message, SystemMessagePromptTemplate):
697
- # Prepend the additional context to the system message
698
- updated_system_message = SystemMessagePromptTemplate.from_template(
699
- additional_context + message.prompt.template
700
- )
701
- # Directly modify the message in the list
702
- self._prompt.messages[i] = updated_system_message
703
- break # Assuming there's only one system message to update
704
-
705
674
  def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
706
675
  """Save a file to disk.
707
676