janus-llm 2.0.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. janus/__init__.py +2 -2
  2. janus/__main__.py +1 -1
  3. janus/_tests/test_cli.py +1 -2
  4. janus/cli.py +43 -51
  5. janus/converter/__init__.py +6 -0
  6. janus/converter/_tests/__init__.py +0 -0
  7. janus/{_tests → converter/_tests}/test_translate.py +11 -22
  8. janus/converter/converter.py +614 -0
  9. janus/converter/diagram.py +124 -0
  10. janus/converter/document.py +131 -0
  11. janus/converter/evaluate.py +15 -0
  12. janus/converter/requirements.py +50 -0
  13. janus/converter/translate.py +108 -0
  14. janus/embedding/_tests/test_collections.py +2 -2
  15. janus/language/_tests/test_splitter.py +1 -1
  16. janus/language/alc/__init__.py +1 -0
  17. janus/language/alc/_tests/__init__.py +0 -0
  18. janus/language/alc/_tests/test_alc.py +28 -0
  19. janus/language/alc/alc.py +87 -0
  20. janus/language/block.py +4 -2
  21. janus/language/combine.py +0 -1
  22. janus/language/mumps/mumps.py +2 -3
  23. janus/language/naive/__init__.py +1 -1
  24. janus/language/naive/basic_splitter.py +4 -4
  25. janus/language/naive/chunk_splitter.py +4 -4
  26. janus/language/naive/registry.py +1 -1
  27. janus/language/naive/simple_ast.py +23 -12
  28. janus/language/naive/tag_splitter.py +4 -4
  29. janus/language/splitter.py +10 -4
  30. janus/language/treesitter/treesitter.py +26 -8
  31. janus/llm/model_callbacks.py +34 -37
  32. janus/llm/models_info.py +16 -3
  33. janus/metrics/_tests/test_llm.py +2 -3
  34. janus/metrics/_tests/test_rouge_score.py +1 -1
  35. janus/metrics/_tests/test_similarity_score.py +1 -1
  36. janus/metrics/complexity_metrics.py +3 -4
  37. janus/metrics/metric.py +3 -4
  38. janus/metrics/reading.py +27 -5
  39. janus/prompts/prompt.py +67 -7
  40. janus/utils/enums.py +6 -5
  41. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/METADATA +1 -1
  42. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/RECORD +45 -35
  43. janus/converter.py +0 -158
  44. janus/translate.py +0 -981
  45. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/LICENSE +0 -0
  46. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/WHEEL +0 -0
  47. {janus_llm-2.0.2.dist-info → janus_llm-3.0.0.dist-info}/entry_points.txt +0 -0
@@ -1,18 +1,29 @@
1
- from janus.language.naive.registry import register_splitter
2
- from janus.language.treesitter import TreeSitterSplitter
3
- from janus.utils.enums import LANGUAGES
1
+ from ...utils.enums import LANGUAGES
2
+ from ..alc.alc import AlcSplitter
3
+ from ..mumps.mumps import MumpsSplitter
4
+ from ..treesitter import TreeSitterSplitter
5
+ from .registry import register_splitter
4
6
 
5
7
 
6
8
  @register_splitter("ast-flex")
7
- class FlexibleTreeSitterSplitter(TreeSitterSplitter):
8
- pass
9
+ def get_flexible_ast(language: str, **kwargs):
10
+ if language == "ibmhlasm":
11
+ return AlcSplitter(**kwargs)
12
+ elif language == "mumps":
13
+ return MumpsSplitter(**kwargs)
14
+ else:
15
+ return TreeSitterSplitter(language=language, **kwargs)
9
16
 
10
17
 
11
18
  @register_splitter("ast-strict")
12
- class StrictTreeSitterSplitter(TreeSitterSplitter):
13
- def __init__(self, language: str, **kwargs):
14
- kwargs.update(
15
- protected_node_types=(LANGUAGES[language]["functional_node_type"],),
16
- prune_unprotected=True,
17
- )
18
- super().__init__(language=language, **kwargs)
19
+ def get_strict_ast(language: str, **kwargs):
20
+ kwargs.update(
21
+ protected_node_types=LANGUAGES[language]["functional_node_types"],
22
+ prune_unprotected=True,
23
+ )
24
+ if language == "ibmhlasm":
25
+ return AlcSplitter(**kwargs)
26
+ elif language == "mumps":
27
+ return MumpsSplitter(**kwargs)
28
+ else:
29
+ return TreeSitterSplitter(language=language, **kwargs)
@@ -1,7 +1,7 @@
1
- from janus.language.block import CodeBlock
2
- from janus.language.naive.registry import register_splitter
3
- from janus.language.node import NodeType
4
- from janus.language.splitter import Splitter
1
+ from ..block import CodeBlock
2
+ from ..node import NodeType
3
+ from ..splitter import Splitter
4
+ from .registry import register_splitter
5
5
 
6
6
 
7
7
  @register_splitter("tag")
@@ -47,8 +47,8 @@ class Splitter(FileManager):
47
47
  model: None | BaseLanguageModel = None,
48
48
  max_tokens: int = 4096,
49
49
  skip_merge: bool = False,
50
- protected_node_types: tuple[str] = (),
51
- prune_node_types: tuple[str] = (),
50
+ protected_node_types: tuple[str, ...] = (),
51
+ prune_node_types: tuple[str, ...] = (),
52
52
  prune_unprotected: bool = False,
53
53
  ):
54
54
  """
@@ -340,7 +340,10 @@ class Splitter(FileManager):
340
340
  # Double check length (in theory this should never be an issue)
341
341
  tokens = self._count_tokens(text)
342
342
  if tokens > self.max_tokens:
343
- log.error(f"Merged node ({name}) too long for context!")
343
+ log.error(
344
+ f"Merged node ({name}) too long for context!"
345
+ f" ({tokens} > {self.max_tokens})"
346
+ )
344
347
 
345
348
  return CodeBlock(
346
349
  text=text,
@@ -420,7 +423,10 @@ class Splitter(FileManager):
420
423
  name = f"{node.name}-L#{node_line}"
421
424
  tokens = self._count_tokens(line)
422
425
  if tokens > self.max_tokens:
423
- raise TokenLimitError(r"Irreducible node too large for context!")
426
+ raise TokenLimitError(
427
+ "Irreducible node too large for context!"
428
+ f" ({tokens} > {self.max_tokens})"
429
+ )
424
430
 
425
431
  node.children.append(
426
432
  CodeBlock(
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import platform
3
3
  from collections import defaultdict
4
+ from ctypes import c_void_p, cdll
4
5
  from pathlib import Path
5
6
  from typing import Optional
6
7
 
@@ -26,8 +27,8 @@ class TreeSitterSplitter(Splitter):
26
27
  language: str,
27
28
  model: None | BaseLanguageModel = None,
28
29
  max_tokens: int = 4096,
29
- protected_node_types: tuple[str] = (),
30
- prune_node_types: tuple[str] = (),
30
+ protected_node_types: tuple[str, ...] = (),
31
+ prune_node_types: tuple[str, ...] = (),
31
32
  prune_unprotected: bool = False,
32
33
  ) -> None:
33
34
  """Initialize a TreeSitterSplitter instance.
@@ -48,10 +49,10 @@ class TreeSitterSplitter(Splitter):
48
49
  self._load_parser()
49
50
 
50
51
  def _get_ast(self, code: str) -> CodeBlock:
51
- code = bytes(code, "utf-8")
52
- tree = self.parser.parse(code)
52
+ code_bytes = bytes(code, "utf-8")
53
+ tree = self.parser.parse(code_bytes)
53
54
  root = tree.walk().node
54
- root = self._node_to_block(root, code)
55
+ root = self._node_to_block(root, code_bytes)
55
56
  return root
56
57
 
57
58
  # Recursively print tree to view parsed output (dev helper function)
@@ -98,7 +99,7 @@ class TreeSitterSplitter(Splitter):
98
99
 
99
100
  text = node.text.decode()
100
101
  children = [self._node_to_block(child, original_text) for child in node.children]
101
- node = CodeBlock(
102
+ return CodeBlock(
102
103
  id=node.id,
103
104
  name=str(node.id),
104
105
  text=text,
@@ -112,7 +113,6 @@ class TreeSitterSplitter(Splitter):
112
113
  language=self.language,
113
114
  tokens=self._count_tokens(text),
114
115
  )
115
- return node
116
116
 
117
117
  def _load_parser(self) -> None:
118
118
  """Load the parser for the given language.
@@ -139,7 +139,25 @@ class TreeSitterSplitter(Splitter):
139
139
 
140
140
  # Load the parser using the generated .so file
141
141
  self.parser: tree_sitter.Parser = tree_sitter.Parser()
142
- self.parser.set_language(tree_sitter.Language(so_file, self.language))
142
+ pointer = self._so_to_pointer(so_file)
143
+ self.parser.set_language(tree_sitter.Language(pointer, self.language))
144
+
145
+ def _so_to_pointer(self, so_file: str) -> int:
146
+ """Convert the .so file to a pointer.
147
+
148
+ Taken from `treesitter.Language.__init__` to get past deprecated warning.
149
+
150
+ Arguments:
151
+ so_file: The path to the so file for the language.
152
+
153
+ Returns:
154
+ The pointer to the language.
155
+ """
156
+ lib = cdll.LoadLibrary(os.fspath(so_file))
157
+ language_function = getattr(lib, f"tree_sitter_{self.language}")
158
+ language_function.restype = c_void_p
159
+ pointer = language_function()
160
+ return pointer
143
161
 
144
162
  def _create_parser(self, so_file: Path | str) -> None:
145
163
  """Create the parser for the given language.
@@ -8,7 +8,7 @@ from langchain_core.messages import AIMessage
8
8
  from langchain_core.outputs import ChatGeneration, LLMResult
9
9
  from langchain_core.tracers.context import register_configure_hook
10
10
 
11
- from janus.utils.logger import create_logger
11
+ from ..utils.logger import create_logger
12
12
 
13
13
  log = create_logger(__name__)
14
14
 
@@ -35,6 +35,9 @@ COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
35
35
  "ai21.j2-mid-v1": {"input": 0.0125, "output": 0.0125},
36
36
  "ai21.j2-ultra-v1": {"input": 0.0188, "output": 0.0188},
37
37
  "cohere.command-r-plus-v1:0": {"input": 0.003, "output": 0.015},
38
+ "mistral.mistral-7b-instruct-v0:2": {"input": 0.00015, "output": 0.0002},
39
+ "mistral.mixtral-8x7b-instruct-v0:1": {"input": 0.00045, "output": 0.0007},
40
+ "mistral.mistral-large-2402-v1:0": {"input": 0.004, "output": 0.012},
38
41
  }
39
42
 
40
43
 
@@ -103,53 +106,47 @@ class TokenUsageCallbackHandler(BaseCallbackHandler):
103
106
  generation = response.generations[0][0]
104
107
  except IndexError:
105
108
  generation = None
106
- if isinstance(generation, ChatGeneration):
107
- try:
108
- message = generation.message
109
- if isinstance(message, AIMessage):
110
- usage_metadata = message.usage_metadata
111
- else:
112
- usage_metadata = None
113
- except AttributeError:
114
- usage_metadata = None
115
- else:
116
- usage_metadata = None
117
- if usage_metadata:
118
- token_usage = {"total_tokens": usage_metadata["total_tokens"]}
119
- completion_tokens = usage_metadata["output_tokens"]
120
- prompt_tokens = usage_metadata["input_tokens"]
121
- if response.llm_output is None:
122
- # model name (and therefore cost) is unavailable in
123
- # streaming responses
124
- model_name = ""
125
- else:
126
- model_name = response.llm_output.get("model_name", "")
127
109
 
110
+ model_id = ""
111
+ usage_metadata = None
112
+ if hasattr(response, "llm_output") and response.llm_output is not None:
113
+ model_id = response.llm_output.get("model_id", model_id)
114
+ model_id = response.llm_output.get("model_name", model_id)
115
+ usage_metadata = response.llm_output.get("usage", usage_metadata)
116
+ usage_metadata = response.llm_output.get("token_usage", usage_metadata)
117
+ elif isinstance(generation, ChatGeneration):
118
+ if hasattr(generation, "response_metadata"):
119
+ model_id = generation.response_metadata.get("model_id", model_id)
120
+ model_id = generation.response_metadata.get("model_name", model_id)
121
+ usage_metadata = generation.response_metadata.get("usage", usage_metadata)
122
+ elif hasattr(generation, "message"):
123
+ if isinstance(generation.message, AIMessage):
124
+ usage_metadata = generation.message.usage_metadata
125
+
126
+ completion_tokens = 0
127
+ prompt_tokens = 0
128
+ total_tokens = 0
129
+ if usage_metadata:
130
+ prompt_tokens = usage_metadata.get("prompt_tokens", prompt_tokens)
131
+ prompt_tokens = usage_metadata.get("input_tokens", prompt_tokens)
132
+ completion_tokens = usage_metadata.get("completion_tokens", completion_tokens)
133
+ completion_tokens = usage_metadata.get("output_tokens", completion_tokens)
134
+ total_tokens = usage_metadata.get("total_tokens", total_tokens)
128
135
  else:
129
- if response.llm_output is None:
130
- return None
131
-
132
- if "token_usage" not in response.llm_output:
133
- with self._lock:
134
- self.successful_requests += 1
135
- return None
136
-
137
- # compute tokens and cost for this request
138
- token_usage = response.llm_output["token_usage"]
139
- completion_tokens = token_usage.get("completion_tokens", 0)
140
- prompt_tokens = token_usage.get("prompt_tokens", 0)
141
- model_name = response.llm_output.get("model_name", "")
136
+ with self._lock:
137
+ self.successful_requests += 1
138
+ return None
142
139
 
143
140
  total_cost = _get_token_cost(
144
141
  prompt_tokens=prompt_tokens,
145
142
  completion_tokens=completion_tokens,
146
- model_id=model_name,
143
+ model_id=model_id,
147
144
  )
148
145
 
149
146
  # update shared state behind lock
150
147
  with self._lock:
151
148
  self.total_cost += total_cost
152
- self.total_tokens += token_usage.get("total_tokens", 0)
149
+ self.total_tokens += total_tokens
153
150
  self.prompt_tokens += prompt_tokens
154
151
  self.completion_tokens += completion_tokens
155
152
  self.successful_requests += 1
janus/llm/models_info.py CHANGED
@@ -8,18 +8,18 @@ from langchain_community.llms import HuggingFaceTextGenInference
8
8
  from langchain_core.language_models import BaseLanguageModel
9
9
  from langchain_openai import ChatOpenAI
10
10
 
11
- from janus.llm.model_callbacks import COST_PER_1K_TOKENS
12
- from janus.prompts.prompt import (
11
+ from ..prompts.prompt import (
13
12
  ChatGptPromptEngine,
14
13
  ClaudePromptEngine,
15
14
  CoherePromptEngine,
16
15
  Llama2PromptEngine,
17
16
  Llama3PromptEngine,
17
+ MistralPromptEngine,
18
18
  PromptEngine,
19
19
  TitanPromptEngine,
20
20
  )
21
-
22
21
  from ..utils.logger import create_logger
22
+ from .model_callbacks import COST_PER_1K_TOKENS
23
23
 
24
24
  log = create_logger(__name__)
25
25
 
@@ -86,12 +86,18 @@ titan_models = [
86
86
  cohere_models = [
87
87
  "bedrock-command-r-plus",
88
88
  ]
89
+ mistral_models = [
90
+ "bedrock-mistral-7b-instruct",
91
+ "bedrock-mistral-large",
92
+ "bedrock-mixtral",
93
+ ]
89
94
  bedrock_models = [
90
95
  *claude_models,
91
96
  *llama2_models,
92
97
  *llama3_models,
93
98
  *titan_models,
94
99
  *cohere_models,
100
+ *mistral_models,
95
101
  ]
96
102
  all_models = [*openai_models, *bedrock_models]
97
103
 
@@ -119,6 +125,7 @@ MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
119
125
  **{m: Llama3PromptEngine for m in llama3_models},
120
126
  **{m: TitanPromptEngine for m in titan_models},
121
127
  **{m: CoherePromptEngine for m in cohere_models},
128
+ **{m: MistralPromptEngine for m in mistral_models},
122
129
  }
123
130
 
124
131
  _open_ai_defaults: dict[str, str] = {
@@ -143,6 +150,9 @@ model_identifiers = {
143
150
  "bedrock-jurassic-2-mid": "ai21.j2-mid-v1",
144
151
  "bedrock-jurassic-2-ultra": "ai21.j2-ultra-v1",
145
152
  "bedrock-command-r-plus": "cohere.command-r-plus-v1:0",
153
+ "bedrock-mixtral": "mistral.mixtral-8x7b-instruct-v0:1",
154
+ "bedrock-mistral-7b-instruct": "mistral.mistral-7b-instruct-v0:2",
155
+ "bedrock-mistral-large": "mistral.mistral-large-2402-v1:0",
146
156
  }
147
157
 
148
158
  MODEL_DEFAULT_ARGUMENTS: dict[str, dict[str, str]] = {
@@ -183,6 +193,9 @@ TOKEN_LIMITS: dict[str, int] = {
183
193
  "ai21.j2-mid-v1": 8192,
184
194
  "ai21.j2-ultra-v1": 8192,
185
195
  "cohere.command-r-plus-v1:0": 128_000,
196
+ "mistral.mixtral-8x7b-instruct-v0:1": 32_000,
197
+ "mistral.mistral-7b-instruct-v0:2": 32_000,
198
+ "mistral.mistral-large-2402-v1:0": 32_000,
186
199
  }
187
200
 
188
201
 
@@ -3,8 +3,7 @@ from unittest.mock import patch
3
3
 
4
4
  import pytest
5
5
 
6
- from janus.llm.models_info import load_model
7
-
6
+ from ...llm.models_info import load_model
8
7
  from ..llm_metrics import llm_evaluate_option, llm_evaluate_ref_option
9
8
 
10
9
 
@@ -40,7 +39,7 @@ class TestLLMMetrics(unittest.TestCase):
40
39
  print("'Hello, world!")
41
40
  """
42
41
 
43
- @patch("janus.llm.models_info.load_model")
42
+ @patch(".llm.models_info.load_model")
44
43
  @patch("janus.metrics.llm_metrics.llm_evaluate")
45
44
  @pytest.mark.llm_eval
46
45
  def test_llm_self_eval_quality(self, mock_llm_evaluate, mock_load_model):
@@ -1,6 +1,6 @@
1
1
  import unittest
2
2
 
3
- from janus.metrics.rouge_score import rouge
3
+ from ..rouge_score import rouge
4
4
 
5
5
 
6
6
  class TestRouge(unittest.TestCase):
@@ -1,6 +1,6 @@
1
1
  import unittest
2
2
 
3
- from janus.metrics.similarity import similarity_score
3
+ from ..similarity import similarity_score
4
4
 
5
5
 
6
6
  class TestSimilarityScore(unittest.TestCase):
@@ -1,10 +1,9 @@
1
1
  import math
2
2
  from typing import List, Optional
3
3
 
4
- from janus.language.block import CodeBlock
5
- from janus.language.treesitter.treesitter import TreeSitterSplitter
6
- from janus.utils.enums import LANGUAGES
7
-
4
+ from ..language.block import CodeBlock
5
+ from ..language.treesitter.treesitter import TreeSitterSplitter
6
+ from ..utils.enums import LANGUAGES
8
7
  from .metric import metric
9
8
 
10
9
 
janus/metrics/metric.py CHANGED
@@ -7,10 +7,9 @@ import click
7
7
  import typer
8
8
  from typing_extensions import Annotated
9
9
 
10
- from janus.llm import load_model
11
- from janus.utils.enums import LANGUAGES
12
- from janus.utils.logger import create_logger
13
-
10
+ from ..llm import load_model
11
+ from ..utils.enums import LANGUAGES
12
+ from ..utils.logger import create_logger
14
13
  from ..utils.progress import track
15
14
  from .cli import evaluate
16
15
  from .file_pairing import FILE_PAIRING_METHODS
janus/metrics/reading.py CHANGED
@@ -1,9 +1,30 @@
1
+ import re
2
+
1
3
  import nltk
2
4
  import readability
5
+ from nltk.tokenize import TweetTokenizer
3
6
 
4
7
  from .metric import metric
5
8
 
6
9
 
10
+ def word_count(text):
11
+ """Calculates word count exactly how readability package does
12
+
13
+ Arguments:
14
+ text: The input string.
15
+
16
+ Returns:
17
+ Word Count
18
+ """
19
+ tokenizer = TweetTokenizer()
20
+ word_count = 0
21
+ tokens = tokenizer.tokenize(text)
22
+ for t in tokens:
23
+ if not re.match(r"^[.,\/#!$%'\^&\*;:{}=\-_`~()]+$", t):
24
+ word_count += 1
25
+ return word_count
26
+
27
+
7
28
  def _repeat_text(text):
8
29
  """Repeats a string until its length is over 100 words.
9
30
 
@@ -20,11 +41,10 @@ def _repeat_text(text):
20
41
  if not text.endswith("."):
21
42
  text += "." # Add a period if missing
22
43
 
23
- # Check if repeated text is long enough, repeat more if needed
24
44
  repeated_text = text
25
- while len(repeated_text.split()) < 100:
26
- repeated_text += " " + text
27
45
 
46
+ while word_count(repeated_text) < 100:
47
+ repeated_text += " " + text
28
48
  return repeated_text
29
49
 
30
50
 
@@ -52,7 +72,8 @@ def flesch(target: str, **kwargs) -> float:
52
72
  Returns:
53
73
  The Flesch score.
54
74
  """
55
-
75
+ if not target.strip(): # Check if the target text is blank
76
+ return None
56
77
  return get_readability(target).flesch().score
57
78
 
58
79
 
@@ -66,5 +87,6 @@ def gunning_fog(target: str, **kwargs) -> float:
66
87
  Returns:
67
88
  The Gunning-Fog score.
68
89
  """
69
-
90
+ if not target.strip(): # Check if the target text is blank
91
+ return None
70
92
  return get_readability(target).gunning_fog().score
janus/prompts/prompt.py CHANGED
@@ -2,12 +2,12 @@ import json
2
2
  from abc import ABC, abstractmethod
3
3
  from pathlib import Path
4
4
 
5
- from langchain import PromptTemplate
6
5
  from langchain.prompts import ChatPromptTemplate
7
6
  from langchain.prompts.chat import (
8
7
  HumanMessagePromptTemplate,
9
8
  SystemMessagePromptTemplate,
10
9
  )
10
+ from langchain_core.prompts import PromptTemplate
11
11
 
12
12
  from ..utils.enums import LANGUAGES
13
13
  from ..utils.logger import create_logger
@@ -34,15 +34,59 @@ HUMAN_PROMPT_TEMPLATE_FILENAME = "human.txt"
34
34
  PROMPT_VARIABLES_FILENAME = "variables.json"
35
35
 
36
36
 
37
+ retry_with_output_prompt_text = """Instructions:
38
+ --------------
39
+ {instructions}
40
+ --------------
41
+ Completion:
42
+ --------------
43
+ {completion}
44
+ --------------
45
+
46
+ Above, the Completion did not satisfy the constraints given in the Instructions.
47
+ Error:
48
+ --------------
49
+ {error}
50
+ --------------
51
+
52
+ Please try again. Please only respond with an answer that satisfies the
53
+ constraints laid out in the Instructions:"""
54
+
55
+
56
+ retry_with_error_and_output_prompt_text = """Prompt:
57
+ --------------
58
+ {prompt}
59
+ --------------
60
+ Completion:
61
+ --------------
62
+ {completion}
63
+ --------------
64
+
65
+ Above, the Completion did not satisfy the constraints given in the Prompt.
66
+ Error:
67
+ --------------
68
+ {error}
69
+ --------------
70
+
71
+ Please try again. Please only respond with an answer that satisfies the
72
+ constraints laid out in the Prompt:"""
73
+
74
+
75
+ retry_with_output_prompt = PromptTemplate.from_template(retry_with_output_prompt_text)
76
+ retry_with_error_and_output_prompt = PromptTemplate.from_template(
77
+ retry_with_error_and_output_prompt_text
78
+ )
79
+
80
+
37
81
  class PromptEngine(ABC):
38
82
  """A class defining prompting schemes for the LLM."""
39
83
 
40
84
  def __init__(
41
85
  self,
42
86
  source_language: str,
43
- target_language: str,
44
- target_version: str,
45
87
  prompt_template: str,
88
+ target_language: str | None = None,
89
+ target_version: str | None = None,
46
90
  ) -> None:
47
91
  """Initialize a PromptEngine instance.
48
92
 
@@ -63,15 +107,18 @@ class PromptEngine(ABC):
63
107
 
64
108
  # Define variables to be passed in to the prompt formatter
65
109
  source_language = source_language.lower()
66
- target_language = target_language.lower()
67
110
  self.variables = dict(
68
111
  SOURCE_LANGUAGE=source_language,
69
- TARGET_LANGUAGE=target_language,
70
- TARGET_LANGUAGE_VERSION=str(target_version),
71
112
  FILE_SUFFIX=LANGUAGES[source_language]["suffix"],
72
113
  SOURCE_CODE_EXAMPLE=LANGUAGES[source_language]["example"],
73
- TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
74
114
  )
115
+ if target_language is not None:
116
+ target_language = target_language.lower()
117
+ self.variables.update(
118
+ TARGET_LANGUAGE=target_language,
119
+ TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
120
+ )
121
+ self.variables.update(TARGET_LANGUAGE_VERSION=str(target_version))
75
122
  variables_path = template_path / PROMPT_VARIABLES_FILENAME
76
123
  if variables_path.exists():
77
124
  self.variables.update(json.loads(variables_path.read_text()))
@@ -219,3 +266,16 @@ class CoherePromptEngine(PromptEngine):
219
266
  f"{human_prompt}"
220
267
  f"<|END_OF_TURN_TOKEN|>"
221
268
  )
269
+
270
+
271
+ class MistralPromptEngine(PromptEngine):
272
+ def load_prompt_template(self, template_path: Path) -> ChatPromptTemplate:
273
+ system_prompt_path = template_path / SYSTEM_PROMPT_TEMPLATE_FILENAME
274
+ system_prompt = system_prompt_path.read_text()
275
+
276
+ human_prompt_path = template_path / HUMAN_PROMPT_TEMPLATE_FILENAME
277
+ human_prompt = human_prompt_path.read_text()
278
+
279
+ return PromptTemplate.from_template(
280
+ f"<s>[INST] {system_prompt} [/INST] </s>[INST] {human_prompt} [/INST]"
281
+ )
janus/utils/enums.py CHANGED
@@ -10,7 +10,7 @@ class EmbeddingType(Enum):
10
10
  TARGET = 5 # placeholder embeddings, are these useful for analysis?
11
11
 
12
12
 
13
- CUSTOM_SPLITTERS: Set[str] = {"mumps", "binary"}
13
+ CUSTOM_SPLITTERS: Set[str] = {"mumps", "binary", "ibmhlasm"}
14
14
 
15
15
  LANGUAGES: Dict[str, Dict[str, Any]] = {
16
16
  "ada": {
@@ -63,7 +63,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
63
63
  '#include <stdio.h>\n\nint main() {\n printf("Hello, World!\\n");\n'
64
64
  " return 0;\n}\n"
65
65
  ),
66
- "functional_node_type": "function_definition",
66
+ "functional_node_types": ["function_definition"],
67
67
  "comment_node_type": "comment",
68
68
  },
69
69
  "capnp": {
@@ -206,7 +206,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
206
206
  "example": (
207
207
  "program HelloWorld\n print *, 'Hello, World!'\nend program HelloWorld\n"
208
208
  ),
209
- "functional_node_type": "function",
209
+ "functional_node_types": ["function"],
210
210
  "comment_node_type": "comment",
211
211
  },
212
212
  "gitattributes": {
@@ -300,6 +300,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
300
300
  END HELLO
301
301
  """
302
302
  ),
303
+ "functional_node_types": ["csect", "dsect"],
303
304
  "branch_node_types": ["branch_instruction"],
304
305
  "operation_node_types": ["operation", "branch_operation"],
305
306
  "operand_node_types": ["operands"],
@@ -420,7 +421,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
420
421
  "suffix": "m",
421
422
  "url": "https://github.com/janus-llm/tree-sitter-mumps",
422
423
  "example": 'WRITE "Hello, World!"',
423
- "functional_node_type": "routine_definition",
424
+ "functional_node_types": ["routine_definition"],
424
425
  "comment_node_type": "comment",
425
426
  "branch_node_types": ["if_statement"],
426
427
  "operation_node_types": [
@@ -512,7 +513,7 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
512
513
  "suffix": "py",
513
514
  "url": "https://github.com/tree-sitter/tree-sitter-python",
514
515
  "example": "# Hello, World!\nprint('Hello, World!')\n",
515
- "functional_node_type": "function_definition",
516
+ "functional_node_types": ["function_definition"],
516
517
  "comment_node_type": "comment",
517
518
  },
518
519
  "qmljs": {
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: janus-llm
3
- Version: 2.0.2
3
+ Version: 3.0.0
4
4
  Summary: A transcoding library using LLMs.
5
5
  Home-page: https://github.com/janus-llm/janus-llm
6
6
  License: Apache 2.0