janus-llm 3.4.0__py3-none-any.whl → 3.4.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
janus/__init__.py CHANGED
@@ -5,7 +5,7 @@ from langchain_core._api.deprecation import LangChainDeprecationWarning
5
5
  from janus.converter.translate import Translator
6
6
  from janus.metrics import * # noqa: F403
7
7
 
8
- __version__ = "3.4.0"
8
+ __version__ = "3.4.1"
9
9
 
10
10
  # Ignoring a deprecation warning from langchain_core that I can't seem to hunt down
11
11
  warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
janus/cli.py CHANGED
@@ -248,6 +248,13 @@ def translate(
248
248
  "If unspecificed, model's default max will be used.",
249
249
  ),
250
250
  ] = None,
251
+ skip_refiner: Annotated[
252
+ bool,
253
+ typer.Option(
254
+ "--skip-refiner",
255
+ help="Whether to skip the refiner for generating output",
256
+ ),
257
+ ] = True,
251
258
  ):
252
259
  try:
253
260
  target_language, target_version = target_lang.split("-")
@@ -274,6 +281,7 @@ def translate(
274
281
  db_config=collections_config,
275
282
  splitter_type=splitter_type,
276
283
  skip_context=skip_context,
284
+ skip_refiner=skip_refiner,
277
285
  )
278
286
  translator.translate(input_dir, output_dir, overwrite, collection)
279
287
 
@@ -395,6 +403,13 @@ def document(
395
403
  "If unspecificed, model's default max will be used.",
396
404
  ),
397
405
  ] = None,
406
+ skip_refiner: Annotated[
407
+ bool,
408
+ typer.Option(
409
+ "--skip-refiner",
410
+ help="Whether to skip the refiner for generating output",
411
+ ),
412
+ ] = True,
398
413
  ):
399
414
  model_arguments = dict(temperature=temperature)
400
415
  collections_config = get_collections_config()
@@ -407,6 +422,7 @@ def document(
407
422
  db_path=db_loc,
408
423
  db_config=collections_config,
409
424
  splitter_type=splitter_type,
425
+ skip_refiner=skip_refiner,
410
426
  skip_context=skip_context,
411
427
  )
412
428
  if doc_mode == "madlibs":
@@ -520,6 +536,13 @@ def diagram(
520
536
  click_type=click.Choice(list(CUSTOM_SPLITTERS.keys())),
521
537
  ),
522
538
  ] = "file",
539
+ skip_refiner: Annotated[
540
+ bool,
541
+ typer.Option(
542
+ "--skip-refiner",
543
+ help="Whether to skip the refiner for generating output",
544
+ ),
545
+ ] = True,
523
546
  ):
524
547
  model_arguments = dict(temperature=temperature)
525
548
  collections_config = get_collections_config()
@@ -533,6 +556,7 @@ def diagram(
533
556
  diagram_type=diagram_type,
534
557
  add_documentation=add_documentation,
535
558
  splitter_type=splitter_type,
559
+ skip_refiner=skip_refiner,
536
560
  skip_context=skip_context,
537
561
  )
538
562
  diagram_generator.translate(input_dir, output_dir, overwrite, collection)
@@ -1,6 +1,5 @@
1
1
  import functools
2
2
  import json
3
- import math
4
3
  import time
5
4
  from pathlib import Path
6
5
  from typing import Any, List, Optional, Tuple
@@ -77,6 +76,7 @@ class Converter:
77
76
  prune_node_types: tuple[str, ...] = (),
78
77
  splitter_type: str = "file",
79
78
  refiner_type: str = "basic",
79
+ skip_refiner: bool = True,
80
80
  skip_context: bool = False,
81
81
  ) -> None:
82
82
  """Initialize a Converter instance.
@@ -98,6 +98,8 @@ class Converter:
98
98
  splitter_type: The type of splitter to use. Valid values are `"file"`,
99
99
  `"tag"`, `"chunk"`, `"ast-strict"`, and `"ast-flex"`.
100
100
  refiner_type: The type of refiner to use. Valid values are `"basic"`.
101
+ skip_refiner: Whether to skip the refiner.
102
+ skip_context: Whether to skip adding context to the prompt.
101
103
  """
102
104
  self._changed_attrs: set = set()
103
105
 
@@ -133,6 +135,8 @@ class Converter:
133
135
  self._refiner_type: str
134
136
  self._refiner: Refiner
135
137
 
138
+ self.skip_refiner = skip_refiner
139
+
136
140
  self.set_splitter(splitter_type=splitter_type)
137
141
  self.set_refiner(refiner_type=refiner_type)
138
142
  self.set_model(model_name=model, **model_arguments)
@@ -293,7 +297,7 @@ class Converter:
293
297
  """
294
298
  if self._refiner_type == "basic":
295
299
  self._refiner = BasicRefiner(
296
- "basic_refinement", self._model_name, self._source_language
300
+ "basic_refinement", self._model_id, self._source_language
297
301
  )
298
302
  else:
299
303
  raise ValueError(f"Error: unknown refiner type {self._refiner_type}")
@@ -598,40 +602,41 @@ class Converter:
598
602
  self._parser.set_reference(block.original)
599
603
 
600
604
  # Retries with just the output and the error
601
- n1 = round(self.max_prompts ** (1 / 3))
605
+ n1 = round(self.max_prompts ** (1 / 2))
602
606
 
603
607
  # Retries with the input, output, and error
604
- n2 = round((self.max_prompts // n1) ** (1 / 2))
608
+ n2 = round(self.max_prompts // n1)
605
609
 
606
610
  # Retries with just the input
607
- n3 = math.ceil(self.max_prompts / (n1 * n2))
608
- # Make replacements in the prompt
609
611
  if not self.skip_context:
610
612
  self._make_prompt_additions(block)
611
-
612
- refine_output = RefinerParser(
613
- parser=self._parser,
614
- initial_prompt=self._prompt.format(**{"SOURCE_CODE": block.original.text}),
615
- refiner=self._refiner,
616
- max_retries=n1,
617
- llm=self._llm,
618
- )
619
- retry = RetryWithErrorOutputParser.from_llm(
620
- llm=self._llm,
621
- parser=refine_output,
622
- max_retries=n2,
623
- )
613
+ if not self.skip_refiner: # Make replacements in the prompt
614
+ refine_output = RefinerParser(
615
+ parser=self._parser,
616
+ initial_prompt=self._prompt.format(
617
+ **{"SOURCE_CODE": block.original.text}
618
+ ),
619
+ refiner=self._refiner,
620
+ max_retries=n1,
621
+ llm=self._llm,
622
+ )
623
+ else:
624
+ refine_output = RetryWithErrorOutputParser.from_llm(
625
+ llm=self._llm,
626
+ parser=self._parser,
627
+ max_retries=n1,
628
+ )
624
629
  completion_chain = self._prompt | self._llm
625
630
  chain = RunnableParallel(
626
631
  completion=completion_chain, prompt_value=self._prompt
627
- ) | RunnableLambda(lambda x: retry.parse_with_prompt(**x))
628
- for _ in range(n3):
632
+ ) | RunnableLambda(lambda x: refine_output.parse_with_prompt(**x))
633
+ for _ in range(n2):
629
634
  try:
630
635
  return chain.invoke({"SOURCE_CODE": block.original.text})
631
636
  except OutputParserException:
632
637
  pass
633
638
 
634
- raise OutputParserException(f"Failed to parse after {n1*n2*n3} retries")
639
+ raise OutputParserException(f"Failed to parse after {n1*n2} retries")
635
640
 
636
641
  def _get_output_obj(
637
642
  self, block: TranslatedCodeBlock
@@ -40,7 +40,9 @@ class RefinerParser(BaseOutputParser):
40
40
  return self.parser.parse(text)
41
41
  except OutputParserException as oe:
42
42
  err = str(oe)
43
- new_prompt, prompt_arguments = self.refiner.refine(last_prompt, text, err)
43
+ new_prompt, prompt_arguments = self.refiner.refine(
44
+ self.initial_prompt, last_prompt, text, err
45
+ )
44
46
  new_chain = new_prompt | self.llm
45
47
  text = new_chain.invoke(prompt_arguments)
46
48
  last_prompt = new_prompt.format(**prompt_arguments)
janus/refiners/refiner.py CHANGED
@@ -5,7 +5,12 @@ from janus.llm.models_info import MODEL_PROMPT_ENGINES
5
5
 
6
6
  class Refiner:
7
7
  def refine(
8
- self, original_prompt: str, original_output: str, errors: str, **kwargs
8
+ self,
9
+ original_prompt: str,
10
+ previous_prompt: str,
11
+ previous_output: str,
12
+ errors: str,
13
+ **kwargs,
9
14
  ) -> tuple[ChatPromptTemplate, dict[str, str]]:
10
15
  """Creates a new prompt based on feedback from original results
11
16
 
@@ -24,22 +29,27 @@ class BasicRefiner(Refiner):
24
29
  def __init__(
25
30
  self,
26
31
  prompt_name: str,
27
- model_name: str,
32
+ model_id: str,
28
33
  source_language: str,
29
34
  ) -> None:
30
35
  """Basic refiner, asks llm to fix output of previous prompt given errors
31
36
 
32
37
  Arguments:
33
38
  prompt_name: refinement prompt name to use
34
- model_name: name of llm to use
39
+ model_id: ID of the llm to use. Found in models_info.py
35
40
  source_language: source_langauge to use
36
41
  """
37
42
  self._prompt_name = prompt_name
38
- self._model_name = model_name
43
+ self._model_id = model_id
39
44
  self._source_language = source_language
40
45
 
41
46
  def refine(
42
- self, original_prompt: str, original_output: str, errors: str, **kwargs
47
+ self,
48
+ original_prompt: str,
49
+ previous_prompt: str,
50
+ previous_output: str,
51
+ errors: str,
52
+ **kwargs,
43
53
  ) -> tuple[ChatPromptTemplate, dict[str, str]]:
44
54
  """Creates a new prompt based on feedback from original results
45
55
 
@@ -51,13 +61,13 @@ class BasicRefiner(Refiner):
51
61
  Returns:
52
62
  Tuple of new prompt and prompt arguments
53
63
  """
54
- prompt_engine = MODEL_PROMPT_ENGINES[self._model_name](
64
+ prompt_engine = MODEL_PROMPT_ENGINES[self._model_id](
55
65
  prompt_template=self._prompt_name,
56
66
  source_language=self._source_language,
57
67
  )
58
68
  prompt_arguments = {
59
69
  "ORIGINAL_PROMPT": original_prompt,
60
- "OUTPUT": original_output,
70
+ "OUTPUT": previous_output,
61
71
  "ERRORS": errors,
62
72
  }
63
73
  return prompt_engine.prompt, prompt_arguments
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: janus-llm
3
- Version: 3.4.0
3
+ Version: 3.4.1
4
4
  Summary: A transcoding library using LLMs.
5
5
  Home-page: https://github.com/janus-llm/janus-llm
6
6
  License: Apache 2.0
@@ -1,13 +1,13 @@
1
- janus/__init__.py,sha256=DDs8SRs9v96slIS8XcqcLZrUAfuwXiT9vFa6OlwYHEY,361
1
+ janus/__init__.py,sha256=XGgjz8H0Qo2cAkx4nww-GdVfdiM8hrSJfv6mGsEwy1s,361
2
2
  janus/__main__.py,sha256=lEkpNtLVPtFo8ySDZeXJ_NXDHb0GVdZFPWB4gD4RPS8,64
3
3
  janus/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  janus/_tests/conftest.py,sha256=V7uW-oq3YbFiRPvrq15YoVVrA1n_83pjgiyTZ-IUGW8,963
5
5
  janus/_tests/test_cli.py,sha256=oYJsUGWfpBJWEGRG5NGxdJedU5DU_m6fwJ7xEbJVYl0,4244
6
- janus/cli.py,sha256=Uvg6xPnJK7dJcDw1J58s3I7qwiNgPfj2qyXwpC6UGbI,32538
6
+ janus/cli.py,sha256=_92FvDV4qza0nSmyiXqacYxyo1gY6IPwD4gCm6kZfqI,33213
7
7
  janus/converter/__init__.py,sha256=U2EOMcCykiC0ZqhorNefOP_04hOF18qhYoPKrVp1Vrk,345
8
8
  janus/converter/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  janus/converter/_tests/test_translate.py,sha256=yzcFEGc_z8QmBBBmC9dZnfL9tT8CD1rkpc8Hz44Jp4c,5631
10
- janus/converter/converter.py,sha256=tBUX9rsu7filfqm6B_knafQUfNb4x4ls8so7KbEq3gs,27160
10
+ janus/converter/converter.py,sha256=CgaNjQE6bz8qCpoYTEwuON40clOq-8r-lp2C173xS0E,27422
11
11
  janus/converter/diagram.py,sha256=5mo1H3Y1uIBPYdIsWz9kxluN5DNyuUMZrtcJmGF2Uw0,5335
12
12
  janus/converter/document.py,sha256=hsW512veNjFWbdl5WriuUdNmMEqZy8ktRvqn9rRmA6E,4566
13
13
  janus/converter/evaluate.py,sha256=APWQUY3gjAXqkJkPzvj0UA4wPK3Cv9QSJLM-YK9t-ng,476
@@ -87,12 +87,12 @@ janus/parsers/_tests/test_code_parser.py,sha256=RVgMmLvg8_57g0uJphfX-jZZsyBqOSuG
87
87
  janus/parsers/code_parser.py,sha256=SZBsYThG4iszKlu4fHoWrs-6cbJiUFjWv4cLSr5bzDM,1790
88
88
  janus/parsers/doc_parser.py,sha256=bJiOE5M7npUZur_1MWJ14C2HZl7-yXExqRXiC5ZBJvI,5679
89
89
  janus/parsers/eval_parser.py,sha256=L1Lu2aNimcqUshe0FQee_9Zqj1rrqyZPXCgEAS05VJ4,2740
90
- janus/parsers/refiner_parser.py,sha256=72tOEhpHwCZqHDb2T4aS5bPsiXN3pQXUk_oOPupa3Ts,1621
90
+ janus/parsers/refiner_parser.py,sha256=5zGoPZyttfRw3kXYDKHId2nhVvAcv1QsvpDFRpe-Few,1680
91
91
  janus/parsers/reqs_parser.py,sha256=6YzpF63rjuDPqpKWfYvtjpsluWQ-UboWlsKoGrGQogA,2380
92
92
  janus/parsers/uml.py,sha256=ZRyGY8YxvYibacTd-WZEAAaW3XjmvJhPJE3o29f71t8,1825
93
93
  janus/prompts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
94
94
  janus/prompts/prompt.py,sha256=3796YXIzzIec9b0iUzd8VZlq-AdQbzq8qUGXLy4KH-0,10586
95
- janus/refiners/refiner.py,sha256=O4i5JaPEWH_ijmHunTKP4YzX_ZwZIyOIckn4Hmf1ZOI,2084
95
+ janus/refiners/refiner.py,sha256=GkV4oUSCrLAhyDJY2aY_Jt8PRF3sC6-bv58YbL2PaNk,2227
96
96
  janus/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
97
97
  janus/utils/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
98
98
  janus/utils/_tests/test_logger.py,sha256=jkkvrCTKwsFCsZtmyuvc-WJ0rC7LJi2Z91sIe4IiKzA,2209
@@ -100,8 +100,8 @@ janus/utils/_tests/test_progress.py,sha256=Rs_u5PiGjP-L-o6C1fhwfE1ig8jYu9Xo9s4p8
100
100
  janus/utils/enums.py,sha256=AoilbdiYyMvY2Mp0AM4xlbLSELfut2XMwhIM1S_msP4,27610
101
101
  janus/utils/logger.py,sha256=KZeuaMAnlSZCsj4yL0P6N-JzZwpxXygzACWfdZFeuek,2337
102
102
  janus/utils/progress.py,sha256=PIpcQec7SrhsfqB25LHj2CDDkfm9umZx90d9LZnAx6k,1469
103
- janus_llm-3.4.0.dist-info/LICENSE,sha256=_j0st0a-HB6MRbP3_BW3PUqpS16v54luyy-1zVyl8NU,10789
104
- janus_llm-3.4.0.dist-info/METADATA,sha256=GmgJ5Oq3MkXOpYw9Vl8YFaLx9dUNcaFZs8GQDV1t8vc,4184
105
- janus_llm-3.4.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
106
- janus_llm-3.4.0.dist-info/entry_points.txt,sha256=OGhQwzj6pvXp79B0SaBD5apGekCu7Dwe9fZZT_TZ544,39
107
- janus_llm-3.4.0.dist-info/RECORD,,
103
+ janus_llm-3.4.1.dist-info/LICENSE,sha256=_j0st0a-HB6MRbP3_BW3PUqpS16v54luyy-1zVyl8NU,10789
104
+ janus_llm-3.4.1.dist-info/METADATA,sha256=lQxVgmdepmQHUey6HX2z0wTUEg6XqpaCu91kWpbeq_E,4184
105
+ janus_llm-3.4.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
106
+ janus_llm-3.4.1.dist-info/entry_points.txt,sha256=OGhQwzj6pvXp79B0SaBD5apGekCu7Dwe9fZZT_TZ544,39
107
+ janus_llm-3.4.1.dist-info/RECORD,,