janus-llm 3.5.3__py3-none-any.whl → 4.1.0__py3-none-any.whl

janus/__init__.py CHANGED
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
 from janus.converter.translate import Translator
 from janus.metrics import *  # noqa: F403
 
-__version__ = "3.5.3"
+__version__ = "4.1.0"
 
 # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
 warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
janus/cli.py CHANGED
@@ -39,10 +39,12 @@ from janus.llm.models_info import (
     MODEL_TYPE_CONSTRUCTORS,
     MODEL_TYPES,
     TOKEN_LIMITS,
+    azure_models,
     bedrock_models,
     openai_models,
 )
 from janus.metrics.cli import evaluate
+from janus.refiners.refiner import REFINERS
 from janus.utils.enums import LANGUAGES
 from janus.utils.logger import create_logger
 
@@ -242,6 +244,24 @@ def translate(
             click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
         ),
     ] = "file",
+    refiner_type: Annotated[
+        str,
+        typer.Option(
+            "-r",
+            "--refiner",
+            help="Name of custom refiner to use",
+            click_type=click.Choice(list(REFINERS.keys())),
+        ),
+    ] = "none",
+    retriever_type: Annotated[
+        str,
+        typer.Option(
+            "-R",
+            "--retriever",
+            help="Name of custom retriever to use",
+            click_type=click.Choice(["active_usings"]),
+        ),
+    ] = None,
     max_tokens: Annotated[
         int,
         typer.Option(
@@ -251,13 +271,6 @@ def translate(
             "If unspecificed, model's default max will be used.",
         ),
     ] = None,
-    skip_refiner: Annotated[
-        bool,
-        typer.Option(
-            "--skip-refiner",
-            help="Whether to skip the refiner for generating output",
-        ),
-    ] = True,
 ):
     try:
         target_language, target_version = target_lang.split("-")
@@ -283,8 +296,8 @@ def translate(
         db_path=db_loc,
         db_config=collections_config,
         splitter_type=splitter_type,
-        skip_context=skip_context,
-        skip_refiner=skip_refiner,
+        refiner_type=refiner_type,
+        retriever_type=retriever_type,
     )
     translator.translate(input_dir, output_dir, overwrite, collection)
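Taken together, the translate() changes drop the boolean --skip-refiner/--skip-context flags in favor of named -r/--refiner and -R/--retriever choices, which the CLI forwards to the converter as refiner_type and retriever_type. A minimal sketch of the equivalent Python-level call, assuming Translator accepts the Converter keyword arguments shown in this diff (the model name, language, and directories are placeholders):

    from janus.converter.translate import Translator

    # Sketch only: kwarg names come from this diff; any other required
    # Translator arguments are omitted here.
    translator = Translator(
        model="gpt-4o",
        source_language="fortran",
        splitter_type="file",
        refiner_type="reflection",       # a REFINERS key; "none" selects no refiner
        retriever_type="active_usings",  # or None for no retrieval context
    )
    translator.translate("input_dir", "output_dir", True, None)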
 
@@ -342,14 +355,6 @@ def document(
             help="Whether to overwrite existing files in the output directory",
         ),
     ] = False,
-    skip_context: Annotated[
-        bool,
-        typer.Option(
-            "--skip-context",
-            help="Prompts will include any context information associated with source"
-            " code blocks, unless this option is specified",
-        ),
-    ] = False,
     doc_mode: Annotated[
         str,
         typer.Option(
@@ -397,6 +402,24 @@ def document(
             click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
         ),
     ] = "file",
+    refiner_type: Annotated[
+        str,
+        typer.Option(
+            "-r",
+            "--refiner",
+            help="Name of custom refiner to use",
+            click_type=click.Choice(list(REFINERS.keys())),
+        ),
+    ] = "none",
+    retriever_type: Annotated[
+        str,
+        typer.Option(
+            "-R",
+            "--retriever",
+            help="Name of custom retriever to use",
+            click_type=click.Choice(["active_usings"]),
+        ),
+    ] = None,
     max_tokens: Annotated[
         int,
         typer.Option(
@@ -406,13 +429,6 @@ def document(
             "If unspecificed, model's default max will be used.",
         ),
     ] = None,
-    skip_refiner: Annotated[
-        bool,
-        typer.Option(
-            "--skip-refiner",
-            help="Whether to skip the refiner for generating output",
-        ),
-    ] = True,
 ):
     model_arguments = dict(temperature=temperature)
     collections_config = get_collections_config()
@@ -425,8 +441,8 @@ def document(
         db_path=db_loc,
         db_config=collections_config,
         splitter_type=splitter_type,
-        skip_refiner=skip_refiner,
-        skip_context=skip_context,
+        refiner_type=refiner_type,
+        retriever_type=retriever_type,
     )
     if doc_mode == "madlibs":
         documenter = MadLibsDocumenter(
@@ -615,14 +631,6 @@ def diagram(
             help="Whether to overwrite existing files in the output directory",
         ),
     ] = False,
-    skip_context: Annotated[
-        bool,
-        typer.Option(
-            "--skip-context",
-            help="Prompts will include any context information associated with source"
-            " code blocks, unless this option is specified",
-        ),
-    ] = False,
     temperature: Annotated[
         float,
         typer.Option("--temperature", "-t", help="Sampling temperature.", min=0, max=2),
@@ -659,13 +667,24 @@ def diagram(
             click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
         ),
     ] = "file",
-    skip_refiner: Annotated[
-        bool,
+    refiner_type: Annotated[
+        str,
         typer.Option(
-            "--skip-refiner",
-            help="Whether to skip the refiner for generating output",
+            "-r",
+            "--refiner",
+            help="Name of custom refiner to use",
+            click_type=click.Choice(list(REFINERS.keys())),
         ),
-    ] = True,
+    ] = "none",
+    retriever_type: Annotated[
+        str,
+        typer.Option(
+            "-R",
+            "--retriever",
+            help="Name of custom retriever to use",
+            click_type=click.Choice(["active_usings"]),
+        ),
+    ] = None,
 ):
     model_arguments = dict(temperature=temperature)
     collections_config = get_collections_config()
@@ -676,11 +695,11 @@ def diagram(
         max_prompts=max_prompts,
         db_path=db_loc,
         db_config=collections_config,
+        splitter_type=splitter_type,
+        refiner_type=refiner_type,
+        retriever_type=retriever_type,
         diagram_type=diagram_type,
         add_documentation=add_documentation,
-        splitter_type=splitter_type,
-        skip_refiner=skip_refiner,
-        skip_context=skip_context,
     )
     diagram_generator.translate(input_dir, output_dir, overwrite, collection)
 
@@ -934,7 +953,7 @@ def llm_add(
             help="The type of the model",
             click_type=click.Choice(sorted(list(MODEL_TYPE_CONSTRUCTORS.keys()))),
         ),
-    ] = "OpenAI",
+    ] = "Azure",
 ):
     if not MODEL_CONFIG_DIR.exists():
         MODEL_CONFIG_DIR.mkdir(parents=True)
@@ -978,6 +997,7 @@ def llm_add(
             "model_cost": {"input": in_cost, "output": out_cost},
         }
     elif model_type == "OpenAI":
+        print("DEPRECATED: Use 'Azure' instead. CTRL+C to exit.")
         model_id = typer.prompt(
             "Enter the model ID (list model IDs with `janus llm ls -a`)",
             default="gpt-4o",
@@ -999,6 +1019,28 @@ def llm_add(
             "token_limit": max_tokens,
             "model_cost": model_cost,
         }
+    elif model_type == "Azure":
+        model_id = typer.prompt(
+            "Enter the model ID (list model IDs with `janus llm ls -a`)",
+            default="gpt-4o",
+            type=click.Choice(azure_models),
+            show_choices=False,
+        )
+        params = dict(
+            # Azure uses the "azure_deployment" key for what we're calling "long_model_id"
+            azure_deployment=MODEL_ID_TO_LONG_ID[model_id],
+            temperature=0.7,
+            n=1,
+        )
+        max_tokens = TOKEN_LIMITS[MODEL_ID_TO_LONG_ID[model_id]]
+        model_cost = COST_PER_1K_TOKENS[MODEL_ID_TO_LONG_ID[model_id]]
+        cfg = {
+            "model_type": model_type,
+            "model_id": model_id,
+            "model_args": params,
+            "token_limit": max_tokens,
+            "model_cost": model_cost,
+        }
     elif model_type == "BedrockChat" or model_type == "Bedrock":
         model_id = typer.prompt(
             "Enter the model ID (list model IDs with `janus llm ls -a`)",
@@ -1173,13 +1215,14 @@ def render(
     for input_file in input_dir.rglob("*.json"):
         with open(input_file, "r") as f:
             data = json.load(f)
-        input_tail = input_file.relative_to(input_dir)
-        output_file = output_dir / input_tail
-        output_file = output_file.with_suffix(".txt")
+
+        output_file = output_dir / input_file.relative_to(input_dir).with_suffix(".txt")
         if not output_file.parent.exists():
             output_file.parent.mkdir()
-        with open(output_file, "w") as f:
-            f.write(data["output"])
+
+        text = data["output"].replace("\\n", "\n").strip()
+        output_file.write_text(text)
+
         jar_path = homedir / ".janus/lib/plantuml.jar"
         subprocess.run(["java", "-jar", jar_path, output_file])  # nosec
         output_file.unlink()
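Besides tightening the path handling, the rewritten loop now unescapes literal backslash-n sequences stored in the JSON "output" field, so PlantUML receives real newlines. A self-contained sketch of just that transformation (the input string is invented):

    # data["output"] can hold two-character "\n" sequences instead of real
    # newlines; render() now normalizes them before writing the .txt file.
    raw = "@startuml\\nBob -> Alice : hello\\n@enduml"  # as read from the JSON
    text = raw.replace("\\n", "\n").strip()
    print(text)  # prints three lines of PlantUML source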
@@ -90,14 +90,14 @@ class TestDiagramGenerator(unittest.TestCase):
     def setUp(self):
         """Set up the tests."""
         self.diagram_generator = DiagramGenerator(
-            model="gpt-4o",
+            model="gpt-4o-mini",
             source_language="fortran",
             diagram_type="Activity",
         )
 
     def test_init(self):
         """Test __init__ method."""
-        self.assertEqual(self.diagram_generator._model_name, "gpt-4o")
+        self.assertEqual(self.diagram_generator._model_name, "gpt-4o-mini")
         self.assertEqual(self.diagram_generator._source_language, "fortran")
         self.assertEqual(self.diagram_generator._diagram_type, "Activity")
 
@@ -2,13 +2,11 @@ import functools
 import json
 import time
 from pathlib import Path
-from typing import Any, List, Optional, Tuple
+from typing import Any
 
-from langchain.output_parsers import RetryWithErrorOutputParser
 from langchain_core.exceptions import OutputParserException
-from langchain_core.language_models import BaseLanguageModel
-from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
-from langchain_core.runnables import RunnableLambda, RunnableParallel
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
 from openai import BadRequestError, RateLimitError
 from pydantic import ValidationError
 
@@ -22,12 +20,18 @@ from janus.language.splitter import (
     Splitter,
     TokenLimitError,
 )
-from janus.llm import load_model
 from janus.llm.model_callbacks import get_model_callback
-from janus.llm.models_info import MODEL_PROMPT_ENGINES
+from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel, load_model
 from janus.parsers.parser import GenericParser, JanusParser
-from janus.parsers.refiner_parser import RefinerParser
-from janus.refiners.refiner import BasicRefiner, Refiner
+from janus.refiners.refiner import (
+    FixParserExceptions,
+    HallucinationRefiner,
+    JanusRefiner,
+    ReflectionRefiner,
+)
+
+# from janus.refiners.refiner import BasicRefiner, Refiner
+from janus.retrievers.retriever import ActiveUsingsRetriever, JanusRetriever
 from janus.utils.enums import LANGUAGES
 from janus.utils.logger import create_logger
 
@@ -74,9 +78,8 @@ class Converter:
         protected_node_types: tuple[str, ...] = (),
         prune_node_types: tuple[str, ...] = (),
         splitter_type: str = "file",
-        refiner_type: str = "basic",
-        skip_refiner: bool = True,
-        skip_context: bool = False,
+        refiner_type: str | None = None,
+        retriever_type: str | None = None,
     ) -> None:
         """Initialize a Converter instance.
 
@@ -96,9 +99,13 @@ class Converter:
             prune_node_types: A set of node types which should be pruned.
             splitter_type: The type of splitter to use. Valid values are `"file"`,
                 `"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
-            refiner_type: The type of refiner to use. Valid values are `"basic"`.
-            skip_refiner: Whether to skip the refiner.
-            skip_context: Whether to skip adding context to the prompt.
+            refiner_type: The type of refiner to use. Valid values:
+                - "parser"
+                - "reflection"
+                - None
+            retriever_type: The type of retriever to use. Valid values:
+                - "active_usings"
+                - None
         """
         self._changed_attrs: set = set()
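Note that the new _load_refiner (later in this diff) also accepts "hallucination", which this docstring omits. A sketch of the dispatch as a lookup table, using the refiner classes imported above; the actual implementation is an if/elif chain, and the behavior comments are assumptions based on the class names:

    REFINER_BY_TYPE = {
        "parser": FixParserExceptions,    # presumably retries on parse failures
        "reflection": ReflectionRefiner,  # presumably critiques and revises output
        "hallucination": HallucinationRefiner,
        None: JanusRefiner,               # parse-only fallback, no refinement
    }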
 
@@ -107,7 +114,6 @@ class Converter:
         self.override_token_limit: bool = max_tokens is not None
 
         self._model_name: str
-        self._model_id: str
         self._custom_model_arguments: dict[str, Any]
 
         self._source_language: str
@@ -120,24 +126,26 @@ class Converter:
         self._prune_node_types: tuple[str, ...] = ()
         self._max_tokens: int | None = max_tokens
         self._prompt_template_name: str
-        self._splitter_type: str
         self._db_path: str | None
         self._db_config: dict[str, Any] | None
 
-        self._splitter: Splitter
-        self._llm: BaseLanguageModel
+        self._llm: JanusModel
         self._prompt: ChatPromptTemplate
 
         self._parser: JanusParser = GenericParser()
         self._combiner: Combiner = Combiner()
 
-        self._refiner_type: str
-        self._refiner: Refiner
+        self._splitter_type: str
+        self._refiner_type: str | None
+        self._retriever_type: str | None
 
-        self.skip_refiner = skip_refiner
+        self._splitter: Splitter
+        self._refiner: JanusRefiner
+        self._retriever: JanusRetriever
 
         self.set_splitter(splitter_type=splitter_type)
         self.set_refiner(refiner_type=refiner_type)
+        self.set_retriever(retriever_type=retriever_type)
         self.set_model(model_name=model, **model_arguments)
         self.set_prompt(prompt_template=prompt_template)
         self.set_source_language(source_language)
@@ -146,8 +154,6 @@ class Converter:
         self.set_db_path(db_path=db_path)
         self.set_db_config(db_config=db_config)
 
-        self.skip_context = skip_context
-
         # Child class must call this. Should we enforce somehow?
         # self._load_parameters()
 
@@ -163,9 +169,11 @@ class Converter:
     def _load_parameters(self) -> None:
         self._load_model()
         self._load_prompt()
+        self._load_retriever()
+        self._load_refiner()
         self._load_splitter()
         self._load_vectorizer()
-        self._load_refiner()
+        self._load_chain()
         self._changed_attrs.clear()
 
     def set_model(self, model_name: str, **custom_arguments: dict[str, Any]):
@@ -184,8 +192,6 @@ class Converter:
     def set_prompt(self, prompt_template: str) -> None:
         """Validate and set the prompt template name.
 
-        The affected objects will not be updated until translate() is called.
-
         Arguments:
             prompt_template: name of prompt template directory
                 (see janus/prompts/templates) or path to a directory.
@@ -195,29 +201,34 @@ class Converter:
     def set_splitter(self, splitter_type: str) -> None:
         """Validate and set the prompt template name.
 
-        The affected objects will not be updated until translate() is called.
-
         Arguments:
             prompt_template: name of prompt template directory
                 (see janus/prompts/templates) or path to a directory.
         """
-        self._splitter_type = splitter_type
+        if splitter_type not in CUSTOM_SPLITTERS:
+            raise ValueError(f'Splitter type "{splitter_type}" does not exist.')
 
-    def set_refiner(self, refiner_type: str) -> None:
-        """Validate and set the refiner name
+        self._splitter_type = splitter_type
 
-        The affected objects will not be updated until translate is called
+    def set_refiner(self, refiner_type: str | None) -> None:
+        """Validate and set the refiner type
 
         Arguments:
-            refiner_type: the name of the refiner to use
+            refiner_type: the type of refiner to use
         """
         self._refiner_type = refiner_type
 
+    def set_retriever(self, retriever_type: str | None) -> None:
+        """Validate and set the retriever type
+
+        Arguments:
+            retriever_type: the type of retriever to use
+        """
+        self._retriever_type = retriever_type
+
     def set_source_language(self, source_language: str) -> None:
         """Validate and set the source language.
 
-        The affected objects will not be updated until _load_parameters() is called.
-
         Arguments:
             source_language: The source programming language.
         """
@@ -287,20 +298,6 @@ class Converter:
 
         self._splitter = CUSTOM_SPLITTERS[self._splitter_type](**kwargs)
 
-    @run_if_changed("_refiner_type", "_model_name")
-    def _load_refiner(self) -> None:
-        """Load the refiner according to this instance's attributes.
-
-        If the relevant fields have not been changed since the last time this method was
-        called, nothing happens.
-        """
-        if self._refiner_type == "basic":
-            self._refiner = BasicRefiner(
-                "basic_refinement", self._model_id, self._source_language
-            )
-        else:
-            raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
-
     @run_if_changed("_model_name", "_custom_model_arguments")
     def _load_model(self) -> None:
         """Load the model according to this instance's attributes.
@@ -314,9 +311,9 @@ class Converter:
         # model_arguments.update(self._custom_model_arguments)
 
         # Load the model
-        self._llm, self._model_id, token_limit, self.model_cost = load_model(
-            self._model_name
-        )
+        self._llm = load_model(self._model_name)
+        token_limit = self._llm.token_limit
+
         # Set the max_tokens to less than half the model's limit to allow for enough
         # tokens at output
         # Only modify max_tokens if it is not specified by user
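load_model now returns a single JanusModel carrying its own metadata, replacing the old (llm, model_id, token_limit, cost) tuple. A sketch of the new contract; the attribute names come from this diff, but this is not runnable without janus installed and a configured model:

    from janus.llm.models_info import load_model

    llm = load_model("gpt-4o")   # a JanusModel
    limit = llm.token_limit      # replaces the tuple's token_limit element
    key = llm.short_model_id     # indexes MODEL_PROMPT_ENGINES (see next hunk)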
@@ -335,7 +332,7 @@ class Converter:
         If the relevant fields have not been changed since the last time this
         method was called, nothing happens.
         """
-        prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
+        prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
             source_language=self._source_language,
             prompt_template=self._prompt_template_name,
         )
@@ -354,6 +351,59 @@ class Converter:
             self._db_path, self._db_config
         )
 
+    @run_if_changed("_retriever_type")
+    def _load_retriever(self):
+        if self._retriever_type == "active_usings":
+            self._retriever = ActiveUsingsRetriever()
+        else:
+            self._retriever = JanusRetriever()
+
+    @run_if_changed("_refiner_type", "_model_name", "max_prompts", "_parser", "_llm")
+    def _load_refiner(self) -> None:
+        """Load the refiner according to this instance's attributes.
+
+        If the relevant fields have not been changed since the last time this method was
+        called, nothing happens.
+        """
+        if self._refiner_type == "parser":
+            self._refiner = FixParserExceptions(
+                llm=self._llm,
+                parser=self._parser,
+                max_retries=self.max_prompts,
+            )
+        elif self._refiner_type == "reflection":
+            self._refiner = ReflectionRefiner(
+                llm=self._llm,
+                parser=self._parser,
+                max_retries=self.max_prompts,
+            )
+        elif self._refiner_type == "hallucination":
+            self._refiner = HallucinationRefiner(
+                llm=self._llm,
+                parser=self._parser,
+                max_retries=self.max_prompts,
+            )
+        else:
+            self._refiner = JanusRefiner(parser=self._parser)
+
+    @run_if_changed("_parser", "_retriever", "_prompt", "_llm", "_refiner")
+    def _load_chain(self):
+        self.chain = (
+            self._input_runnable()
+            | self._prompt
+            | RunnableParallel(
+                completion=self._llm,
+                prompt_value=RunnablePassthrough(),
+            )
+            | self._refiner.parse_runnable
+        )
+
+    def _input_runnable(self) -> Runnable:
+        return RunnableParallel(
+            SOURCE_CODE=self._parser.parse_input,
+            context=self._retriever,
+        )
+
     def translate(
         self,
         input_directory: str | Path,
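The per-block pipeline is now a single LCEL chain assembled once: one RunnableParallel gathers the parsed source and the retriever's context, the prompt and model run, and a second RunnableParallel keeps the prompt value alongside the completion for the refiner. A toy sketch of that shape, runnable with langchain_core alone; each lambda is a stand-in for the real parser, retriever, prompt, model, or refiner:

    from langchain_core.runnables import (
        RunnableLambda,
        RunnableParallel,
        RunnablePassthrough,
    )

    parse_input = RunnableLambda(lambda block: str(block))
    retrieve = RunnableLambda(lambda block: "")  # stand-in retriever: no context
    prompt = RunnableLambda(lambda d: f"{d['context']}{d['SOURCE_CODE']}")
    llm = RunnableLambda(str.upper)  # stand-in for the chat model
    refine = RunnableLambda(lambda d: d["completion"])  # stand-in parse_runnable

    chain = (
        RunnableParallel(SOURCE_CODE=parse_input, context=retrieve)
        | prompt
        | RunnableParallel(completion=llm, prompt_value=RunnablePassthrough())
        | refine
    )
    print(chain.invoke("x = 1"))  # prints "X = 1"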
@@ -598,110 +648,29 @@ class Converter:
         return root
 
     def _run_chain(self, block: TranslatedCodeBlock) -> str:
-        """Run the model with three nested error fixing schemes.
-        First, try to fix simple formatting errors by giving the model just
-        the output and the parsing error. After a number of attempts, try
-        giving the model the output, the parsing error, and the original
-        input. Again check/retry this output to solve for formatting errors.
-        If we still haven't succeeded after several attempts, the model may
-        be getting thrown off by a bad initial output; start from scratch
-        and try again.
-
-        The number of tries for each layer of this scheme is roughly equal
-        to the cube root of self.max_retries, so the total calls to the
-        LLM will be roughly as expected (up to sqrt(self.max_retries) over)
-        """
-        input = self._parser.parse_input(block.original)
-
-        # Retries with just the output and the error
-        n1 = round(self.max_prompts ** (1 / 2))
-
-        # Retries with the input, output, and error
-        n2 = round(self.max_prompts // n1)
-
-        if not self.skip_context:
-            self._make_prompt_additions(block)
-        if not self.skip_refiner:  # Make replacements in the prompt
-            refine_output = RefinerParser(
-                parser=self._parser,
-                initial_prompt=self._prompt.format(**{"SOURCE_CODE": input}),
-                refiner=self._refiner,
-                max_retries=n1,
-                llm=self._llm,
-            )
-        else:
-            refine_output = RetryWithErrorOutputParser.from_llm(
-                llm=self._llm,
-                parser=self._parser,
-                max_retries=n1,
-            )
-
-        completion_chain = self._prompt | self._llm
-        chain = RunnableParallel(
-            completion=completion_chain, prompt_value=self._prompt
-        ) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
-        for _ in range(n2):
-            try:
-                return chain.invoke({"SOURCE_CODE": input})
-            except OutputParserException:
-                pass
-
-        raise OutputParserException(f"Failed to parse after {n1*n2} retries")
+        return self.chain.invoke(block.original)
 
     def _get_output_obj(
         self, block: TranslatedCodeBlock
-    ) -> dict[str, int | float | str | dict[str, str]]:
+    ) -> dict[str, int | float | str | dict[str, str] | dict[str, float]]:
         output_str = self._parser.parse_combined_output(block.complete_text)
 
-        output: str | dict[str, str]
+        output_obj: str | dict[str, str]
         try:
-            output = json.loads(output_str)
+            output_obj = json.loads(output_str)
         except json.JSONDecodeError:
-            output = output_str
+            output_obj = output_str
 
         return dict(
-            input=block.original.text,
+            input=block.original.text or "",
            metadata=dict(
                 retries=block.total_retries,
                 cost=block.total_cost,
                 processing_time=block.processing_time,
             ),
-            output=output,
-        )
-
-    @staticmethod
-    def _get_prompt_additions(block) -> Optional[List[Tuple[str, str]]]:
-        """Get a list of strings to append to the prompt.
-
-        Arguments:
-            block: The `TranslatedCodeBlock` to save to a file.
-        """
-        return [(key, item) for key, item in block.context_tags.items()]
-
-    def _make_prompt_additions(self, block: CodeBlock):
-        # Prepare the additional context to prepend
-        additional_context = "".join(
-            [
-                f"{context_tag}: {context}\n"
-                for context_tag, context in self._get_prompt_additions(block)
-            ]
+            output=output_obj,
         )
 
-        if not hasattr(self._prompt, "messages"):
-            log.debug("Skipping additions to prompt, no messages found on prompt object!")
-            return
-
-        # Iterate through existing messages to find and update the system message
-        for i, message in enumerate(self._prompt.messages):
-            if isinstance(message, SystemMessagePromptTemplate):
-                # Prepend the additional context to the system message
-                updated_system_message = SystemMessagePromptTemplate.from_template(
-                    additional_context + message.prompt.template
-                )
-                # Directly modify the message in the list
-                self._prompt.messages[i] = updated_system_message
-                break  # Assuming there's only one system message to update
-
     def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
         """Save a file to disk.
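Finally, for orientation: the per-block object _get_output_obj now builds has the shape below. Field names come from the diff; every value is invented. "output" holds parsed JSON when the combined output parses, otherwise the raw string:

    example_output_obj = {
        "input": "      SUBROUTINE DEMO ...",  # block.original.text or ""
        "metadata": {
            "retries": 1,            # block.total_retries
            "cost": 0.0042,          # block.total_cost
            "processing_time": 3.1,  # block.processing_time
        },
        "output": "...translated text, or a parsed JSON object...",
    }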