janus-llm 2.0.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. janus/__init__.py +2 -2
  2. janus/__main__.py +1 -1
  3. janus/_tests/test_cli.py +1 -2
  4. janus/cli.py +43 -51
  5. janus/converter/__init__.py +6 -0
  6. janus/converter/_tests/__init__.py +0 -0
  7. janus/{_tests → converter/_tests}/test_translate.py +11 -22
  8. janus/converter/converter.py +614 -0
  9. janus/converter/diagram.py +124 -0
  10. janus/converter/document.py +131 -0
  11. janus/converter/evaluate.py +15 -0
  12. janus/converter/requirements.py +50 -0
  13. janus/converter/translate.py +108 -0
  14. janus/embedding/_tests/test_collections.py +2 -2
  15. janus/language/_tests/test_splitter.py +1 -1
  16. janus/language/alc/__init__.py +1 -0
  17. janus/language/alc/_tests/__init__.py +0 -0
  18. janus/language/alc/_tests/test_alc.py +28 -0
  19. janus/language/alc/alc.py +87 -0
  20. janus/language/block.py +4 -2
  21. janus/language/combine.py +0 -1
  22. janus/language/mumps/mumps.py +2 -3
  23. janus/language/naive/__init__.py +1 -1
  24. janus/language/naive/basic_splitter.py +4 -4
  25. janus/language/naive/chunk_splitter.py +4 -4
  26. janus/language/naive/registry.py +1 -1
  27. janus/language/naive/simple_ast.py +23 -12
  28. janus/language/naive/tag_splitter.py +4 -4
  29. janus/language/splitter.py +10 -4
  30. janus/language/treesitter/treesitter.py +26 -8
  31. janus/llm/model_callbacks.py +34 -37
  32. janus/llm/models_info.py +16 -3
  33. janus/metrics/_tests/test_llm.py +2 -3
  34. janus/metrics/_tests/test_rouge_score.py +1 -1
  35. janus/metrics/_tests/test_similarity_score.py +1 -1
  36. janus/metrics/complexity_metrics.py +3 -4
  37. janus/metrics/metric.py +3 -4
  38. janus/metrics/reading.py +27 -5
  39. janus/prompts/prompt.py +67 -7
  40. janus/utils/enums.py +6 -5
  41. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/METADATA +1 -1
  42. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/RECORD +45 -35
  43. janus/converter.py +0 -158
  44. janus/translate.py +0 -981
  45. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/LICENSE +0 -0
  46. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/WHEEL +0 -0
  47. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/entry_points.txt +0 -0
@@ -1,18 +1,29 @@
1
- from janus.language.naive.registry import register_splitter
2
- from janus.language.treesitter import TreeSitterSplitter
3
- from janus.utils.enums import LANGUAGES
1
+ from ...utils.enums import LANGUAGES
2
+ from ..alc.alc import AlcSplitter
3
+ from ..mumps.mumps import MumpsSplitter
4
+ from ..treesitter import TreeSitterSplitter
5
+ from .registry import register_splitter
4
6
 
5
7
 
6
8
  @register_splitter("ast-flex")
7
- class FlexibleTreeSitterSplitter(TreeSitterSplitter):
8
- pass
9
+ def get_flexible_ast(language: str, **kwargs):
10
+ if language == "ibmhlasm":
11
+ return AlcSplitter(**kwargs)
12
+ elif language == "mumps":
13
+ return MumpsSplitter(**kwargs)
14
+ else:
15
+ return TreeSitterSplitter(language=language, **kwargs)
9
16
 
10
17
 
11
18
  @register_splitter("ast-strict")
12
- class StrictTreeSitterSplitter(TreeSitterSplitter):
13
- def __init__(self, language: str, **kwargs):
14
- kwargs.update(
15
- protected_node_types=(LANGUAGES[language]["functional_node_type"],),
16
- prune_unprotected=True,
17
- )
18
- super().__init__(language=language, **kwargs)
19
+ def get_strict_ast(language: str, **kwargs):
20
+ kwargs.update(
21
+ protected_node_types=LANGUAGES[language]["functional_node_types"],
22
+ prune_unprotected=True,
23
+ )
24
+ if language == "ibmhlasm":
25
+ return AlcSplitter(**kwargs)
26
+ elif language == "mumps":
27
+ return MumpsSplitter(**kwargs)
28
+ else:
29
+ return TreeSitterSplitter(language=language, **kwargs)
@@ -1,7 +1,7 @@
1
- from janus.language.block import CodeBlock
2
- from janus.language.naive.registry import register_splitter
3
- from janus.language.node import NodeType
4
- from janus.language.splitter import Splitter
1
+ from ..block import CodeBlock
2
+ from ..node import NodeType
3
+ from ..splitter import Splitter
4
+ from .registry import register_splitter
5
5
 
6
6
 
7
7
  @register_splitter("tag")
@@ -47,8 +47,8 @@ class Splitter(FileManager):
47
47
  model: None | BaseLanguageModel = None,
48
48
  max_tokens: int = 4096,
49
49
  skip_merge: bool = False,
50
- protected_node_types: tuple[str] = (),
51
- prune_node_types: tuple[str] = (),
50
+ protected_node_types: tuple[str, ...] = (),
51
+ prune_node_types: tuple[str, ...] = (),
52
52
  prune_unprotected: bool = False,
53
53
  ):
54
54
  """
@@ -340,7 +340,10 @@ class Splitter(FileManager):
340
340
  # Double check length (in theory this should never be an issue)
341
341
  tokens = self._count_tokens(text)
342
342
  if tokens > self.max_tokens:
343
- log.error(f"Merged node ({name}) too long for context!")
343
+ log.error(
344
+ f"Merged node ({name}) too long for context!"
345
+ f" ({tokens} > {self.max_tokens})"
346
+ )
344
347
 
345
348
  return CodeBlock(
346
349
  text=text,
@@ -420,7 +423,10 @@ class Splitter(FileManager):
420
423
  name = f"{node.name}-L#{node_line}"
421
424
  tokens = self._count_tokens(line)
422
425
  if tokens > self.max_tokens:
423
- raise TokenLimitError(r"Irreducible node too large for context!")
426
+ raise TokenLimitError(
427
+ "Irreducible node too large for context!"
428
+ f" ({tokens} > {self.max_tokens})"
429
+ )
424
430
 
425
431
  node.children.append(
426
432
  CodeBlock(
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import platform
3
3
  from collections import defaultdict
4
+ from ctypes import c_void_p, cdll
4
5
  from pathlib import Path
5
6
  from typing import Optional
6
7
 
@@ -26,8 +27,8 @@ class TreeSitterSplitter(Splitter):
26
27
  language: str,
27
28
  model: None | BaseLanguageModel = None,
28
29
  max_tokens: int = 4096,
29
- protected_node_types: tuple[str] = (),
30
- prune_node_types: tuple[str] = (),
30
+ protected_node_types: tuple[str, ...] = (),
31
+ prune_node_types: tuple[str, ...] = (),
31
32
  prune_unprotected: bool = False,
32
33
  ) -> None:
33
34
  """Initialize a TreeSitterSplitter instance.
@@ -48,10 +49,10 @@ class TreeSitterSplitter(Splitter):
48
49
  self._load_parser()
49
50
 
50
51
  def _get_ast(self, code: str) -> CodeBlock:
51
- code = bytes(code, "utf-8")
52
- tree = self.parser.parse(code)
52
+ code_bytes = bytes(code, "utf-8")
53
+ tree = self.parser.parse(code_bytes)
53
54
  root = tree.walk().node
54
- root = self._node_to_block(root, code)
55
+ root = self._node_to_block(root, code_bytes)
55
56
  return root
56
57
 
57
58
  # Recursively print tree to view parsed output (dev helper function)
@@ -98,7 +99,7 @@ class TreeSitterSplitter(Splitter):
98
99
 
99
100
  text = node.text.decode()
100
101
  children = [self._node_to_block(child, original_text) for child in node.children]
101
- node = CodeBlock(
102
+ return CodeBlock(
102
103
  id=node.id,
103
104
  name=str(node.id),
104
105
  text=text,
@@ -112,7 +113,6 @@ class TreeSitterSplitter(Splitter):
112
113
  language=self.language,
113
114
  tokens=self._count_tokens(text),
114
115
  )
115
- return node
116
116
 
117
117
  def _load_parser(self) -> None:
118
118
  """Load the parser for the given language.
@@ -139,7 +139,25 @@ class TreeSitterSplitter(Splitter):
139
139
 
140
140
  # Load the parser using the generated .so file
141
141
  self.parser: tree_sitter.Parser = tree_sitter.Parser()
142
- self.parser.set_language(tree_sitter.Language(so_file, self.language))
142
+ pointer = self._so_to_pointer(so_file)
143
+ self.parser.set_language(tree_sitter.Language(pointer, self.language))
144
+
145
+ def _so_to_pointer(self, so_file: str) -> int:
146
+ """Convert the .so file to a pointer.
147
+
148
+ Taken from `treesitter.Language.__init__` to get past deprecated warning.
149
+
150
+ Arguments:
151
+ so_file: The path to the so file for the language.
152
+
153
+ Returns:
154
+ The pointer to the language.
155
+ """
156
+ lib = cdll.LoadLibrary(os.fspath(so_file))
157
+ language_function = getattr(lib, f"tree_sitter_{self.language}")
158
+ language_function.restype = c_void_p
159
+ pointer = language_function()
160
+ return pointer
143
161
 
144
162
  def _create_parser(self, so_file: Path | str) -> None:
145
163
  """Create the parser for the given language.
@@ -8,7 +8,7 @@ from langchain_core.messages import AIMessage
8
8
  from langchain_core.outputs import ChatGeneration, LLMResult
9
9
  from langchain_core.tracers.context import register_configure_hook
10
10
 
11
- from janus.utils.logger import create_logger
11
+ from ..utils.logger import create_logger
12
12
 
13
13
  log = create_logger(__name__)
14
14
 
@@ -35,6 +35,9 @@ COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
35
35
  "ai21.j2-mid-v1": {"input": 0.0125, "output": 0.0125},
36
36
  "ai21.j2-ultra-v1": {"input": 0.0188, "output": 0.0188},
37
37
  "cohere.command-r-plus-v1:0": {"input": 0.003, "output": 0.015},
38
+ "mistral.mistral-7b-instruct-v0:2": {"input": 0.00015, "output": 0.0002},
39
+ "mistral.mixtral-8x7b-instruct-v0:1": {"input": 0.00045, "output": 0.0007},
40
+ "mistral.mistral-large-2402-v1:0": {"input": 0.004, "output": 0.012},
38
41
  }
39
42
 
40
43
 
@@ -103,53 +106,47 @@ class TokenUsageCallbackHandler(BaseCallbackHandler):
103
106
  generation = response.generations[0][0]
104
107
  except IndexError:
105
108
  generation = None
106
- if isinstance(generation, ChatGeneration):
107
- try:
108
- message = generation.message
109
- if isinstance(message, AIMessage):
110
- usage_metadata = message.usage_metadata
111
- else:
112
- usage_metadata = None
113
- except AttributeError:
114
- usage_metadata = None
115
- else:
116
- usage_metadata = None
117
- if usage_metadata:
118
- token_usage = {"total_tokens": usage_metadata["total_tokens"]}
119
- completion_tokens = usage_metadata["output_tokens"]
120
- prompt_tokens = usage_metadata["input_tokens"]
121
- if response.llm_output is None:
122
- # model name (and therefore cost) is unavailable in
123
- # streaming responses
124
- model_name = ""
125
- else:
126
- model_name = response.llm_output.get("model_name", "")
127
109
 
110
+ model_id = ""
111
+ usage_metadata = None
112
+ if hasattr(response, "llm_output") and response.llm_output is not None:
113
+ model_id = response.llm_output.get("model_id", model_id)
114
+ model_id = response.llm_output.get("model_name", model_id)
115
+ usage_metadata = response.llm_output.get("usage", usage_metadata)
116
+ usage_metadata = response.llm_output.get("token_usage", usage_metadata)
117
+ elif isinstance(generation, ChatGeneration):
118
+ if hasattr(generation, "response_metadata"):
119
+ model_id = generation.response_metadata.get("model_id", model_id)
120
+ model_id = generation.response_metadata.get("model_name", model_id)
121
+ usage_metadata = generation.response_metadata.get("usage", usage_metadata)
122
+ elif hasattr(generation, "message"):
123
+ if isinstance(generation.message, AIMessage):
124
+ usage_metadata = generation.message.usage_metadata
125
+
126
+ completion_tokens = 0
127
+ prompt_tokens = 0
128
+ total_tokens = 0
129
+ if usage_metadata:
130
+ prompt_tokens = usage_metadata.get("prompt_tokens", prompt_tokens)
131
+ prompt_tokens = usage_metadata.get("input_tokens", prompt_tokens)
132
+ completion_tokens = usage_metadata.get("completion_tokens", completion_tokens)
133
+ completion_tokens = usage_metadata.get("output_tokens", completion_tokens)
134
+ total_tokens = usage_metadata.get("total_tokens", total_tokens)
128
135
  else:
129
- if response.llm_output is None:
130
- return None
131
-
132
- if "token_usage" not in response.llm_output:
133
- with self._lock:
134
- self.successful_requests += 1
135
- return None
136
-
137
- # compute tokens and cost for this request
138
- token_usage = response.llm_output["token_usage"]
139
- completion_tokens = token_usage.get("completion_tokens", 0)
140
- prompt_tokens = token_usage.get("prompt_tokens", 0)
141
- model_name = response.llm_output.get("model_name", "")
136
+ with self._lock:
137
+ self.successful_requests += 1
138
+ return None
142
139
 
143
140
  total_cost = _get_token_cost(
144
141
  prompt_tokens=prompt_tokens,
145
142
  completion_tokens=completion_tokens,
146
- model_id=model_name,
143
+ model_id=model_id,
147
144
  )
148
145
 
149
146
  # update shared state behind lock
150
147
  with self._lock:
151
148
  self.total_cost += total_cost
152
- self.total_tokens += token_usage.get("total_tokens", 0)
149
+ self.total_tokens += total_tokens
153
150
  self.prompt_tokens += prompt_tokens
154
151
  self.completion_tokens += completion_tokens
155
152
  self.successful_requests += 1
janus/llm/models_info.py CHANGED
@@ -8,18 +8,18 @@ from langchain_community.llms import HuggingFaceTextGenInference
8
8
  from langchain_core.language_models import BaseLanguageModel
9
9
  from langchain_openai import ChatOpenAI
10
10
 
11
- from janus.llm.model_callbacks import COST_PER_1K_TOKENS
12
- from janus.prompts.prompt import (
11
+ from ..prompts.prompt import (
13
12
  ChatGptPromptEngine,
14
13
  ClaudePromptEngine,
15
14
  CoherePromptEngine,
16
15
  Llama2PromptEngine,
17
16
  Llama3PromptEngine,
17
+ MistralPromptEngine,
18
18
  PromptEngine,
19
19
  TitanPromptEngine,
20
20
  )
21
-
22
21
  from ..utils.logger import create_logger
22
+ from .model_callbacks import COST_PER_1K_TOKENS
23
23
 
24
24
  log = create_logger(__name__)
25
25
 
@@ -86,12 +86,18 @@ titan_models = [
86
86
  cohere_models = [
87
87
  "bedrock-command-r-plus",
88
88
  ]
89
+ mistral_models = [
90
+ "bedrock-mistral-7b-instruct",
91
+ "bedrock-mistral-large",
92
+ "bedrock-mixtral",
93
+ ]
89
94
  bedrock_models = [
90
95
  *claude_models,
91
96
  *llama2_models,
92
97
  *llama3_models,
93
98
  *titan_models,
94
99
  *cohere_models,
100
+ *mistral_models,
95
101
  ]
96
102
  all_models = [*openai_models, *bedrock_models]
97
103
 
@@ -119,6 +125,7 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
119
125
  **{m: Llama3PromptEngine for m in llama3_models},
120
126
  **{m: TitanPromptEngine for m in titan_models},
121
127
  **{m: CoherePromptEngine for m in cohere_models},
128
+ **{m: MistralPromptEngine for m in mistral_models},
122
129
  }
123
130
 
124
131
  _open_ai_defaults: dict[str, str] = {
@@ -143,6 +150,9 @@ model_identifiers = {
143
150
  "bedrock-jurassic-2-mid": "ai21.j2-mid-v1",
144
151
  "bedrock-jurassic-2-ultra": "ai21.j2-ultra-v1",
145
152
  "bedrock-command-r-plus": "cohere.command-r-plus-v1:0",
153
+ "bedrock-mixtral": "mistral.mixtral-8x7b-instruct-v0:1",
154
+ "bedrock-mistral-7b-instruct": "mistral.mistral-7b-instruct-v0:2",
155
+ "bedrock-mistral-large": "mistral.mistral-large-2402-v1:0",
146
156
  }
147
157
 
148
158
  MODEL_DEFAULT_ARGUMENTS: dict[str, dict[str, str]] = {
@@ -183,6 +193,9 @@ TOKEN_LIMITS: dict[str, int] = {
183
193
  "ai21.j2-mid-v1": 8192,
184
194
  "ai21.j2-ultra-v1": 8192,
185
195
  "cohere.command-r-plus-v1:0": 128_000,
196
+ "mistral.mixtral-8x7b-instruct-v0:1": 32_000,
197
+ "mistral.mistral-7b-instruct-v0:2": 32_000,
198
+ "mistral.mistral-large-2402-v1:0": 32_000,
186
199
  }
187
200
 
188
201
 
@@ -3,8 +3,7 @@ from unittest.mock import patch
3
3
 
4
4
  import pytest
5
5
 
6
- from janus.llm.models_info import load_model
7
-
6
+ from ...llm.models_info import load_model
8
7
  from ..llm_metrics import llm_evaluate_option, llm_evaluate_ref_option
9
8
 
10
9
 
@@ -40,7 +39,7 @@ class TestLLMMetrics(unittest.TestCase):
40
39
  print("'Hello, world!")
41
40
  """
42
41
 
43
- @patch("janus.llm.models_info.load_model")
42
+ @patch(".llm.models_info.load_model")
44
43
  @patch("janus.metrics.llm_metrics.llm_evaluate")
45
44
  @pytest.mark.llm_eval
46
45
  def test_llm_self_eval_quality(self, mock_llm_evaluate, mock_load_model):
@@ -1,6 +1,6 @@
1
1
  import unittest
2
2
 
3
- from janus.metrics.rouge_score import rouge
3
+ from ..rouge_score import rouge
4
4
 
5
5
 
6
6
  class TestRouge(unittest.TestCase):
@@ -1,6 +1,6 @@
1
1
  import unittest
2
2
 
3
- from janus.metrics.similarity import similarity_score
3
+ from ..similarity import similarity_score
4
4
 
5
5
 
6
6
  class TestSimilarityScore(unittest.TestCase):
@@ -1,10 +1,9 @@
1
1
  import math
2
2
  from typing import List, Optional
3
3
 
4
- from janus.language.block import CodeBlock
5
- from janus.language.treesitter.treesitter import TreeSitterSplitter
6
- from janus.utils.enums import LANGUAGES
7
-
4
+ from ..language.block import CodeBlock
5
+ from ..language.treesitter.treesitter import TreeSitterSplitter
6
+ from ..utils.enums import LANGUAGES
8
7
  from .metric import metric
9
8
 
10
9
 
janus/metrics/metric.py CHANGED
@@ -7,10 +7,9 @@ import click
7
7
  import typer
8
8
  from typing_extensions import Annotated
9
9
 
10
- from janus.llm import load_model
11
- from janus.utils.enums import LANGUAGES
12
- from janus.utils.logger import create_logger
13
-
10
+ from ..llm import load_model
11
+ from ..utils.enums import LANGUAGES
12
+ from ..utils.logger import create_logger
14
13
  from ..utils.progress import track
15
14
  from .cli import evaluate
16
15
  from .file_pairing import FILE_PAIRING_METHODS
janus/metrics/reading.py CHANGED
@@ -1,9 +1,30 @@
1
+ import re
2
+
1
3
  import nltk
2
4
  import readability
5
+ from nltk.tokenize import TweetTokenizer
3
6
 
4
7
  from .metric import metric
5
8
 
6
9
 
10
+ def word_count(text):
11
+ """Calculates word count exactly how readability package does
12
+
13
+ Arguments:
14
+ text: The input string.
15
+
16
+ Returns:
17
+ Word Count
18
+ """
19
+ tokenizer = TweetTokenizer()
20
+ word_count = 0
21
+ tokens = tokenizer.tokenize(text)
22
+ for t in tokens:
23
+ if not re.match(r"^[.,\/#!$%'\^&\*;:{}=\-_`~()]+$", t):
24
+ word_count += 1
25
+ return word_count
26
+
27
+
7
28
  def _repeat_text(text):
8
29
  """Repeats a string until its length is over 100 words.
9
30
 
@@ -20,11 +41,10 @@ def _repeat_text(text):
20
41
  if not text.endswith("."):
21
42
  text += "." # Add a period if missing
22
43
 
23
- # Check if repeated text is long enough, repeat more if needed
24
44
  repeated_text = text
25
- while len(repeated_text.split()) < 100:
26
- repeated_text += " " + text
27
45
 
46
+ while word_count(repeated_text) < 100:
47
+ repeated_text += " " + text
28
48
  return repeated_text
29
49
 
30
50
 
@@ -52,7 +72,8 @@ def flesch(target: str, **kwargs) -> float:
52
72
  Returns:
53
73
  The Flesch score.
54
74
  """
55
-
75
+ if not target.strip(): # Check if the target text is blank
76
+ return None
56
77
  return get_readability(target).flesch().score
57
78
 
58
79
 
@@ -66,5 +87,6 @@ def gunning_fog(target: str, **kwargs) -> float:
66
87
  Returns:
67
88
  The Gunning-Fog score.
68
89
  """
69
-
90
+ if not target.strip(): # Check if the target text is blank
91
+ return None
70
92
  return get_readability(target).gunning_fog().score
janus/prompts/prompt.py CHANGED
@@ -2,12 +2,12 @@ import json
2
2
  from abc import ABC, abstractmethod
3
3
  from pathlib import Path
4
4
 
5
- from langchain import PromptTemplate
6
5
  from langchain.prompts import ChatPromptTemplate
7
6
  from langchain.prompts.chat import (
8
7
  HumanMessagePromptTemplate,
9
8
  SystemMessagePromptTemplate,
10
9
  )
10
+ from langchain_core.prompts import PromptTemplate
11
11
 
12
12
  from ..utils.enums import LANGUAGES
13
13
  from ..utils.logger import create_logger
@@ -34,15 +34,59 @@ HUMAN_PROMPT_TEMPLATE_FILENAME = "human.txt"
34
34
  PROMPT_VARIABLES_FILENAME = "variables.json"
35
35
 
36
36
 
37
+ retry_with_output_prompt_text = """Instructions:
38
+ --------------
39
+ {instructions}
40
+ --------------
41
+ Completion:
42
+ --------------
43
+ {completion}
44
+ --------------
45
+
46
+ Above, the Completion did not satisfy the constraints given in the Instructions.
47
+ Error:
48
+ --------------
49
+ {error}
50
+ --------------
51
+
52
+ Please try again. Please only respond with an answer that satisfies the
53
+ constraints laid out in the Instructions:"""
54
+
55
+
56
+ retry_with_error_and_output_prompt_text = """Prompt:
57
+ --------------
58
+ {prompt}
59
+ --------------
60
+ Completion:
61
+ --------------
62
+ {completion}
63
+ --------------
64
+
65
+ Above, the Completion did not satisfy the constraints given in the Prompt.
66
+ Error:
67
+ --------------
68
+ {error}
69
+ --------------
70
+
71
+ Please try again. Please only respond with an answer that satisfies the
72
+ constraints laid out in the Prompt:"""
73
+
74
+
75
+ retry_with_output_prompt = PromptTemplate.from_template(retry_with_output_prompt_text)
76
+ retry_with_error_and_output_prompt = PromptTemplate.from_template(
77
+ retry_with_error_and_output_prompt_text
78
+ )
79
+
80
+
37
81
  class PromptEngine(ABC):
38
82
  """A class defining prompting schemes for the LLM."""
39
83
 
40
84
  def __init__(
41
85
  self,
42
86
  source_language: str,
43
- target_language: str,
44
- target_version: str,
45
87
  prompt_template: str,
88
+ target_language: str | None = None,
89
+ target_version: str | None = None,
46
90
  ) -> None:
47
91
  """Initialize a PromptEngine instance.
48
92
 
@@ -63,15 +107,18 @@ class PromptEngine(ABC):
63
107
 
64
108
  # Define variables to be passed in to the prompt formatter
65
109
  source_language = source_language.lower()
66
- target_language = target_language.lower()
67
110
  self.variables = dict(
68
111
  SOURCE_LANGUAGE=source_language,
69
- TARGET_LANGUAGE=target_language,
70
- TARGET_LANGUAGE_VERSION=str(target_version),
71
112
  FILE_SUFFIX=LANGUAGES[source_language]["suffix"],
72
113
  SOURCE_CODE_EXAMPLE=LANGUAGES[source_language]["example"],
73
- TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
74
114
  )
115
+ if target_language is not None:
116
+ target_language = target_language.lower()
117
+ self.variables.update(
118
+ TARGET_LANGUAGE=target_language,
119
+ TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
120
+ )
121
+ self.variables.update(TARGET_LANGUAGE_VERSION=str(target_version))
75
122
  variables_path = template_path / PROMPT_VARIABLES_FILENAME
76
123
  if variables_path.exists():
77
124
  self.variables.update(json.loads(variables_path.read_text()))
@@ -219,3 +266,16 @@ class CoherePromptEngine(PromptEngine):
219
266
  f"{human_prompt}"
220
267
  f"<|END_OF_TURN_TOKEN|>"
221
268
  )
269
+
270
+
271
+ class MistralPromptEngine(PromptEngine):
272
+ def load_prompt_template(self, template_path: Path) -> ChatPromptTemplate:
273
+ system_prompt_path = template_path / SYSTEM_PROMPT_TEMPLATE_FILENAME
274
+ system_prompt = system_prompt_path.read_text()
275
+
276
+ human_prompt_path = template_path / HUMAN_PROMPT_TEMPLATE_FILENAME
277
+ human_prompt = human_prompt_path.read_text()
278
+
279
+ return PromptTemplate.from_template(
280
+ f"<s>[INST] {system_prompt} [/INST] </s>[INST] {human_prompt} [/INST]"
281
+ )
janus/utils/enums.py CHANGED
@@ -10,7 +10,7 @@ class EmbeddingType(Enum):
10
10
  TARGET = 5 # placeholder embeddings, are these useful for analysis?
11
11
 
12
12
 
13
- CUSTOM_SPLITTERS: Set[str] = {"mumps", "binary"}
13
+ CUSTOM_SPLITTERS: Set[str] = {"mumps", "binary", "ibmhlasm"}
14
14
 
15
15
  LANGUAGES: Dict[str, Dict[str, Any]] = {
16
16
  "ada": {
@@ -63,7 +63,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
63
63
  '#include <stdio.h>\n\nint main() {\n printf("Hello, World!\\n");\n'
64
64
  " return 0;\n}\n"
65
65
  ),
66
- "functional_node_type": "function_definition",
66
+ "functional_node_types": ["function_definition"],
67
67
  "comment_node_type": "comment",
68
68
  },
69
69
  "capnp": {
@@ -206,7 +206,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
206
206
  "example": (
207
207
  "program HelloWorld\n print *, 'Hello, World!'\nend program HelloWorld\n"
208
208
  ),
209
- "functional_node_type": "function",
209
+ "functional_node_types": ["function"],
210
210
  "comment_node_type": "comment",
211
211
  },
212
212
  "gitattributes": {
@@ -300,6 +300,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
300
300
  END HELLO
301
301
  """
302
302
  ),
303
+ "functional_node_types": ["csect", "dsect"],
303
304
  "branch_node_types": ["branch_instruction"],
304
305
  "operation_node_types": ["operation", "branch_operation"],
305
306
  "operand_node_types": ["operands"],
@@ -420,7 +421,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
420
421
  "suffix": "m",
421
422
  "url": "https://github.com/janus-llm/tree-sitter-mumps",
422
423
  "example": 'WRITE "Hello, World!"',
423
- "functional_node_type": "routine_definition",
424
+ "functional_node_types": ["routine_definition"],
424
425
  "comment_node_type": "comment",
425
426
  "branch_node_types": ["if_statement"],
426
427
  "operation_node_types": [
@@ -512,7 +513,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
512
513
  "suffix": "py",
513
514
  "url": "https://github.com/tree-sitter/tree-sitter-python",
514
515
  "example": "# Hello, World!\nprint('Hello, World!')\n",
515
- "functional_node_type": "function_definition",
516
+ "functional_node_types": ["function_definition"],
516
517
  "comment_node_type": "comment",
517
518
  },
518
519
  "qmljs": {
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: janus-llm
3
- Version: 2.0.2
3
+ Version: 3.0.0
4
4
  Summary: A transcoding library using LLMs.
5
5
  Home-page: https://github.com/janus-llm/janus-llm
6
6
  License: Apache 2.0