janus-llm 2.1.0__py3-none-any.whl → 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janus/__init__.py +2 -2
- janus/__main__.py +1 -1
- janus/_tests/test_cli.py +1 -2
- janus/cli.py +43 -50
- janus/converter/__init__.py +6 -0
- janus/converter/_tests/__init__.py +0 -0
- janus/{_tests → converter/_tests}/test_translate.py +11 -22
- janus/converter/converter.py +614 -0
- janus/converter/diagram.py +124 -0
- janus/converter/document.py +131 -0
- janus/converter/evaluate.py +15 -0
- janus/converter/requirements.py +51 -0
- janus/converter/translate.py +108 -0
- janus/language/block.py +1 -1
- janus/language/combine.py +0 -1
- janus/language/treesitter/treesitter.py +20 -1
- janus/llm/model_callbacks.py +33 -36
- janus/llm/models_info.py +14 -0
- janus/metrics/reading.py +27 -5
- janus/prompts/prompt.py +37 -11
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/METADATA +1 -1
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/RECORD +25 -19
- janus/converter.py +0 -161
- janus/translate.py +0 -987
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/LICENSE +0 -0
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/WHEEL +0 -0
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/entry_points.txt +0 -0
janus/metrics/reading.py
CHANGED
@@ -1,9 +1,30 @@
|
|
1
|
+
import re
|
2
|
+
|
1
3
|
import nltk
|
2
4
|
import readability
|
5
|
+
from nltk.tokenize import TweetTokenizer
|
3
6
|
|
4
7
|
from .metric import metric
|
5
8
|
|
6
9
|
|
10
|
+
def word_count(text):
|
11
|
+
"""Calculates word count exactly how readability package does
|
12
|
+
|
13
|
+
Arguments:
|
14
|
+
text: The input string.
|
15
|
+
|
16
|
+
Returns:
|
17
|
+
Word Count
|
18
|
+
"""
|
19
|
+
tokenizer = TweetTokenizer()
|
20
|
+
word_count = 0
|
21
|
+
tokens = tokenizer.tokenize(text)
|
22
|
+
for t in tokens:
|
23
|
+
if not re.match(r"^[.,\/#!$%'\^&\*;:{}=\-_`~()]+$", t):
|
24
|
+
word_count += 1
|
25
|
+
return word_count
|
26
|
+
|
27
|
+
|
7
28
|
def _repeat_text(text):
|
8
29
|
"""Repeats a string until its length is over 100 words.
|
9
30
|
|
@@ -20,11 +41,10 @@ def _repeat_text(text):
|
|
20
41
|
if not text.endswith("."):
|
21
42
|
text += "." # Add a period if missing
|
22
43
|
|
23
|
-
# Check if repeated text is long enough, repeat more if needed
|
24
44
|
repeated_text = text
|
25
|
-
while len(repeated_text.split()) < 100:
|
26
|
-
repeated_text += " " + text
|
27
45
|
|
46
|
+
while word_count(repeated_text) < 100:
|
47
|
+
repeated_text += " " + text
|
28
48
|
return repeated_text
|
29
49
|
|
30
50
|
|
@@ -52,7 +72,8 @@ def flesch(target: str, **kwargs) -> float:
|
|
52
72
|
Returns:
|
53
73
|
The Flesch score.
|
54
74
|
"""
|
55
|
-
|
75
|
+
if not target.strip(): # Check if the target text is blank
|
76
|
+
return None
|
56
77
|
return get_readability(target).flesch().score
|
57
78
|
|
58
79
|
|
@@ -66,5 +87,6 @@ def gunning_fog(target: str, **kwargs) -> float:
|
|
66
87
|
Returns:
|
67
88
|
The Gunning-Fog score.
|
68
89
|
"""
|
69
|
-
|
90
|
+
if not target.strip(): # Check if the target text is blank
|
91
|
+
return None
|
70
92
|
return get_readability(target).gunning_fog().score
|
janus/prompts/prompt.py
CHANGED
@@ -2,12 +2,12 @@ import json
|
|
2
2
|
from abc import ABC, abstractmethod
|
3
3
|
from pathlib import Path
|
4
4
|
|
5
|
-
from langchain import PromptTemplate
|
6
5
|
from langchain.prompts import ChatPromptTemplate
|
7
6
|
from langchain.prompts.chat import (
|
8
7
|
HumanMessagePromptTemplate,
|
9
8
|
SystemMessagePromptTemplate,
|
10
9
|
)
|
10
|
+
from langchain_core.prompts import PromptTemplate
|
11
11
|
|
12
12
|
from ..utils.enums import LANGUAGES
|
13
13
|
from ..utils.logger import create_logger
|
@@ -40,7 +40,7 @@ retry_with_output_prompt_text = """Instructions:
|
|
40
40
|
--------------
|
41
41
|
Completion:
|
42
42
|
--------------
|
43
|
-
{
|
43
|
+
{completion}
|
44
44
|
--------------
|
45
45
|
|
46
46
|
Above, the Completion did not satisfy the constraints given in the Instructions.
|
@@ -54,13 +54,23 @@ constraints laid out in the Instructions:"""
|
|
54
54
|
|
55
55
|
|
56
56
|
retry_with_error_and_output_prompt_text = """Prompt:
|
57
|
+
--------------
|
57
58
|
{prompt}
|
59
|
+
--------------
|
58
60
|
Completion:
|
59
|
-
|
61
|
+
--------------
|
62
|
+
{completion}
|
63
|
+
--------------
|
60
64
|
|
61
65
|
Above, the Completion did not satisfy the constraints given in the Prompt.
|
62
|
-
|
63
|
-
|
66
|
+
Error:
|
67
|
+
--------------
|
68
|
+
{error}
|
69
|
+
--------------
|
70
|
+
|
71
|
+
Please try again. Please only respond with an answer that satisfies the
|
72
|
+
constraints laid out in the Prompt:"""
|
73
|
+
|
64
74
|
|
65
75
|
retry_with_output_prompt = PromptTemplate.from_template(retry_with_output_prompt_text)
|
66
76
|
retry_with_error_and_output_prompt = PromptTemplate.from_template(
|
@@ -74,9 +84,9 @@ class PromptEngine(ABC):
|
|
74
84
|
def __init__(
|
75
85
|
self,
|
76
86
|
source_language: str,
|
77
|
-
target_language: str,
|
78
|
-
target_version: str,
|
79
87
|
prompt_template: str,
|
88
|
+
target_language: str | None = None,
|
89
|
+
target_version: str | None = None,
|
80
90
|
) -> None:
|
81
91
|
"""Initialize a PromptEngine instance.
|
82
92
|
|
@@ -97,15 +107,18 @@ class PromptEngine(ABC):
|
|
97
107
|
|
98
108
|
# Define variables to be passed in to the prompt formatter
|
99
109
|
source_language = source_language.lower()
|
100
|
-
target_language = target_language.lower()
|
101
110
|
self.variables = dict(
|
102
111
|
SOURCE_LANGUAGE=source_language,
|
103
|
-
TARGET_LANGUAGE=target_language,
|
104
|
-
TARGET_LANGUAGE_VERSION=str(target_version),
|
105
112
|
FILE_SUFFIX=LANGUAGES[source_language]["suffix"],
|
106
113
|
SOURCE_CODE_EXAMPLE=LANGUAGES[source_language]["example"],
|
107
|
-
TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
|
108
114
|
)
|
115
|
+
if target_language is not None:
|
116
|
+
target_language = target_language.lower()
|
117
|
+
self.variables.update(
|
118
|
+
TARGET_LANGUAGE=target_language,
|
119
|
+
TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
|
120
|
+
)
|
121
|
+
self.variables.update(TARGET_LANGUAGE_VERSION=str(target_version))
|
109
122
|
variables_path = template_path / PROMPT_VARIABLES_FILENAME
|
110
123
|
if variables_path.exists():
|
111
124
|
self.variables.update(json.loads(variables_path.read_text()))
|
@@ -253,3 +266,16 @@ class CoherePromptEngine(PromptEngine):
|
|
253
266
|
f"{human_prompt}"
|
254
267
|
f"<|END_OF_TURN_TOKEN|>"
|
255
268
|
)
|
269
|
+
|
270
|
+
|
271
|
+
class MistralPromptEngine(PromptEngine):
|
272
|
+
def load_prompt_template(self, template_path: Path) -> ChatPromptTemplate:
|
273
|
+
system_prompt_path = template_path / SYSTEM_PROMPT_TEMPLATE_FILENAME
|
274
|
+
system_prompt = system_prompt_path.read_text()
|
275
|
+
|
276
|
+
human_prompt_path = template_path / HUMAN_PROMPT_TEMPLATE_FILENAME
|
277
|
+
human_prompt = human_prompt_path.read_text()
|
278
|
+
|
279
|
+
return PromptTemplate.from_template(
|
280
|
+
f"<s>[INST] {system_prompt} [/INST] </s>[INST] {human_prompt} [/INST]"
|
281
|
+
)
|
@@ -1,11 +1,18 @@
|
|
1
|
-
janus/__init__.py,sha256=
|
2
|
-
janus/__main__.py,sha256=
|
1
|
+
janus/__init__.py,sha256=OkY9msOgEgU_jai5YdGjfv6hQG6UopI7S2J_M9EKs-U,351
|
2
|
+
janus/__main__.py,sha256=lEkpNtLVPtFo8ySDZeXJ_NXDHb0GVdZFPWB4gD4RPS8,64
|
3
3
|
janus/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
janus/_tests/conftest.py,sha256=V7uW-oq3YbFiRPvrq15YoVVrA1n_83pjgiyTZ-IUGW8,963
|
5
|
-
janus/_tests/test_cli.py,sha256=
|
6
|
-
janus/
|
7
|
-
janus/
|
8
|
-
janus/converter.py,sha256=
|
5
|
+
janus/_tests/test_cli.py,sha256=mi7wAWV07ZFli5nQdExRGIGA3AMFD9s39-HcmDV4B6Y,4232
|
6
|
+
janus/cli.py,sha256=-aeg8R6CobK2EG_BPoZgBy_x1d6G9gp-KKKhnLMepo4,29541
|
7
|
+
janus/converter/__init__.py,sha256=kzVmWOPXRDayqqBZ8ZDaFQzA_q8PEdv407dc-DefPxY,255
|
8
|
+
janus/converter/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
+
janus/converter/_tests/test_translate.py,sha256=eiLbmouokZrAeAYmdoJgnlx5-k4QiO6i0N6e6ZvZsvM,15885
|
10
|
+
janus/converter/converter.py,sha256=Bq07_9N_3Dv9NBqVACvb7LC2HxdQmfVZ1b0BlWrxjgo,23521
|
11
|
+
janus/converter/diagram.py,sha256=v-3ZZ4t1q74lDOjF2N6NRPkC3IK-sjLDn5_VChZTEGA,4608
|
12
|
+
janus/converter/document.py,sha256=hsW512veNjFWbdl5WriuUdNmMEqZy8ktRvqn9rRmA6E,4566
|
13
|
+
janus/converter/evaluate.py,sha256=APWQUY3gjAXqkJkPzvj0UA4wPK3Cv9QSJLM-YK9t-ng,476
|
14
|
+
janus/converter/requirements.py,sha256=6YvrJRVH9BuPCOPxnXmaJQFYmoLYYvCu3zTntDLHeNg,1832
|
15
|
+
janus/converter/translate.py,sha256=kMlGUiBYGQBXSxwX5in3CUyUifPM95wynCaRMxSDxMw,4238
|
9
16
|
janus/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
10
17
|
janus/embedding/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
18
|
janus/embedding/_tests/test_collections.py,sha256=eT0cYv-qmPrHJRjDZqWPFTkqVzFDRoPrRKR__FPiz58,2651
|
@@ -28,8 +35,8 @@ janus/language/binary/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
|
|
28
35
|
janus/language/binary/_tests/test_binary.py,sha256=a-8RSfKA23UrJC9c1xPQK792XZCz8npCHI7isN2dAP8,1727
|
29
36
|
janus/language/binary/binary.py,sha256=CS1RAieN8klSsCeXQEFYKUWioatUX-sOPXKQr5S6NzE,6534
|
30
37
|
janus/language/binary/reveng/decompile_script.py,sha256=veW51oJzuO-4UD3Er062jXZ_FYtTFo9OCkl82Z2xr6A,2182
|
31
|
-
janus/language/block.py,sha256=
|
32
|
-
janus/language/combine.py,sha256=
|
38
|
+
janus/language/block.py,sha256=57hfOY-KSVMioKhkCvfDtovQt4h8lCg9cJbRF7ddV1s,9280
|
39
|
+
janus/language/combine.py,sha256=e7j8zQO_D3_LElaVCsGgtnzia7aFFK56m-mhArQBlR0,2908
|
33
40
|
janus/language/file.py,sha256=X2MYcAMlCABK77uhMdI_J2foXLrqEdinapYRfLPyKB8,563
|
34
41
|
janus/language/mumps/__init__.py,sha256=-Ou_wJ-JgHezfp1dub2_qCYNiK9wO-zo2MlqxM9qiwE,48
|
35
42
|
janus/language/mumps/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -47,10 +54,10 @@ janus/language/splitter.py,sha256=4XAe0hXka7njS30UHGCngJzDgHxn3lygUjikSHuV7Xo,16
|
|
47
54
|
janus/language/treesitter/__init__.py,sha256=mUliw7ZJLZ8NkJKyUQMSoUV82hYXE0HvLHrEdGPJF4Q,43
|
48
55
|
janus/language/treesitter/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
49
56
|
janus/language/treesitter/_tests/test_treesitter.py,sha256=nsavUV0aI6cpT9FkQve58eTTehLyQG6qJJBGlNa_bIw,2170
|
50
|
-
janus/language/treesitter/treesitter.py,sha256=
|
57
|
+
janus/language/treesitter/treesitter.py,sha256=UiV4OuWTt6IwMohHSw4FHsVNA_zxr9lNk4_Du09APdo,7509
|
51
58
|
janus/llm/__init__.py,sha256=8Pzn3Jdx867PzDc4xmwm8wvJDGzWSIhpN0NCEYFe0LQ,36
|
52
|
-
janus/llm/model_callbacks.py,sha256=
|
53
|
-
janus/llm/models_info.py,sha256=
|
59
|
+
janus/llm/model_callbacks.py,sha256=h_xlBAHRx-gxQBBjVKRpGXxdxYf6d9L6kBoXjbEAEdI,7106
|
60
|
+
janus/llm/models_info.py,sha256=B9Dn5mHc43OeZe5mHFj5wuhO194XHCTwShAa2ybnPyY,7688
|
54
61
|
janus/metrics/__init__.py,sha256=AsxtZJUzZiXJPr2ehPPltuYP-ddechjg6X85WZUO7mA,241
|
55
62
|
janus/metrics/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
56
63
|
janus/metrics/_tests/reference.py,sha256=hiaJPP9CXkvFBV_wL-gOe_BzELTw0nvB6uCxhxtIiE8,13
|
@@ -70,7 +77,7 @@ janus/metrics/complexity_metrics.py,sha256=1Z9n0o_CrILqayk40wRkjR1f7yvHIsJG38DxA
|
|
70
77
|
janus/metrics/file_pairing.py,sha256=WNHRV1D8GOJMq8Pla5SPkTDAT7yVaS4-UU0XIGKvEVs,3729
|
71
78
|
janus/metrics/llm_metrics.py,sha256=3677S6GYcoVcokpmAN-fwvNu-lYWAKd7M5mebiE6RZc,5687
|
72
79
|
janus/metrics/metric.py,sha256=Lgdtq87oJ-kWC_6jdPQ6-d1MqoeTnhkRszo6IZJV6c0,16974
|
73
|
-
janus/metrics/reading.py,sha256=
|
80
|
+
janus/metrics/reading.py,sha256=srLb2MO-vZL5ccRjaHz-dA4MwAvXVNyIKnOrvJXg77E,2244
|
74
81
|
janus/metrics/rouge_score.py,sha256=HfUJwUWI-yq5pOjML2ee4QTOMl0NQahnqEY2Mt8Dtnw,2865
|
75
82
|
janus/metrics/similarity.py,sha256=9pjWWpLKCsk0QfFfSgQNdPXiisqi7WJYOOHaiT8S0iY,1613
|
76
83
|
janus/metrics/splitting.py,sha256=610ScHRvALwdkqA6YyGI-tr3a18_cUofldBxGYX0SwE,968
|
@@ -82,8 +89,7 @@ janus/parsers/doc_parser.py,sha256=X8eCb1QXbL6sVWLEFGjsPyxrpJ9XnOPg7G4KZSo9A9E,5
|
|
82
89
|
janus/parsers/eval_parser.py,sha256=HB5-zY_Jpmkj6FDbuNCCVCRxwmzhViSAjPKbyyC0Ebc,2723
|
83
90
|
janus/parsers/reqs_parser.py,sha256=MFBvtR3otpyPZlkZxu0dVH1YeEJhvhNzhaGKGHaQVHA,2359
|
84
91
|
janus/prompts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
85
|
-
janus/prompts/prompt.py,sha256=
|
86
|
-
janus/translate.py,sha256=bIrvyFBXUH1Cf8M-h-qSybFe0NQwuCA38heiV2toP8w,38958
|
92
|
+
janus/prompts/prompt.py,sha256=vd7UbitF0VFCi21RsggDebD51xcuyls_lQLGKkphfI8,10578
|
87
93
|
janus/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
88
94
|
janus/utils/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
89
95
|
janus/utils/_tests/test_logger.py,sha256=4jZFm8LX828Dt9lOjiFHZIPbxYy_hHaswyrMPkscgdM,2199
|
@@ -91,8 +97,8 @@ janus/utils/_tests/test_progress.py,sha256=Yh5NDNq-24n2nhHHbJm39pENAH70PYnh9ymwd
|
|
91
97
|
janus/utils/enums.py,sha256=AoilbdiYyMvY2Mp0AM4xlbLSELfut2XMwhIM1S_msP4,27610
|
92
98
|
janus/utils/logger.py,sha256=KZeuaMAnlSZCsj4yL0P6N-JzZwpxXygzACWfdZFeuek,2337
|
93
99
|
janus/utils/progress.py,sha256=pKcCzO9JOU9fSD7qTmLWcqY5smc8mujqQMXoPgqNysE,1458
|
94
|
-
janus_llm-
|
95
|
-
janus_llm-
|
96
|
-
janus_llm-
|
97
|
-
janus_llm-
|
98
|
-
janus_llm-
|
100
|
+
janus_llm-3.0.1.dist-info/LICENSE,sha256=_j0st0a-HB6MRbP3_BW3PUqpS16v54luyy-1zVyl8NU,10789
|
101
|
+
janus_llm-3.0.1.dist-info/METADATA,sha256=tVkX6eswouFez-ASfy-iWJk_iXptEeylGETEZF4MeeI,4184
|
102
|
+
janus_llm-3.0.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
103
|
+
janus_llm-3.0.1.dist-info/entry_points.txt,sha256=OGhQwzj6pvXp79B0SaBD5apGekCu7Dwe9fZZT_TZ544,39
|
104
|
+
janus_llm-3.0.1.dist-info/RECORD,,
|
janus/converter.py
DELETED
@@ -1,161 +0,0 @@
|
|
1
|
-
import functools
|
2
|
-
from typing import Any
|
3
|
-
|
4
|
-
from langchain.schema.language_model import BaseLanguageModel
|
5
|
-
|
6
|
-
from .language.alc.alc import AlcSplitter
|
7
|
-
from .language.binary import BinarySplitter
|
8
|
-
from .language.mumps import MumpsSplitter
|
9
|
-
from .language.splitter import Splitter
|
10
|
-
from .language.treesitter import TreeSitterSplitter
|
11
|
-
from .utils.enums import CUSTOM_SPLITTERS, LANGUAGES
|
12
|
-
from .utils.logger import create_logger
|
13
|
-
|
14
|
-
log = create_logger(__name__)
|
15
|
-
|
16
|
-
|
17
|
-
def run_if_changed(*tracked_vars):
|
18
|
-
"""Wrapper to skip function calls if the given instance attributes haven't
|
19
|
-
been updated. Requires the _changed_attrs set to exist, and the __setattr__
|
20
|
-
method to be overridden to track parameter updates in _changed_attrs.
|
21
|
-
"""
|
22
|
-
|
23
|
-
def wrapper(func):
|
24
|
-
@functools.wraps(func)
|
25
|
-
def wrapped(self, *args, **kwargs):
|
26
|
-
# If there is overlap between the tracked variables and the changed
|
27
|
-
# ones, then call the function as normal
|
28
|
-
if self._changed_attrs.intersection(tracked_vars):
|
29
|
-
func(self, *args, **kwargs)
|
30
|
-
|
31
|
-
return wrapped
|
32
|
-
|
33
|
-
return wrapper
|
34
|
-
|
35
|
-
|
36
|
-
class Converter:
|
37
|
-
"""Parent class that converts code into something else.
|
38
|
-
|
39
|
-
Children will determine what the code gets converted into. Whether that's translated
|
40
|
-
into another language, into pseudocode, requirements, documentation, etc., or
|
41
|
-
converted into embeddings
|
42
|
-
"""
|
43
|
-
|
44
|
-
def __init__(
|
45
|
-
self,
|
46
|
-
source_language: str = "fortran",
|
47
|
-
max_tokens: None | int = None,
|
48
|
-
protected_node_types: set[str] | list[str] | tuple[str] = (),
|
49
|
-
prune_node_types: set[str] | list[str] | tuple[str] = (),
|
50
|
-
) -> None:
|
51
|
-
"""Initialize a Converter instance.
|
52
|
-
|
53
|
-
Arguments:
|
54
|
-
source_language: The source programming language.
|
55
|
-
parser_type: The type of parser to use for parsing the LLM output. Valid
|
56
|
-
values are `"code"`, `"text"`, `"eval"`, and `None` (default). If `None`,
|
57
|
-
the `Converter` assumes you won't be parsing an output (i.e., adding to an
|
58
|
-
embedding DB).
|
59
|
-
"""
|
60
|
-
self._changed_attrs: set = set()
|
61
|
-
|
62
|
-
self._source_language: None | str
|
63
|
-
self._source_glob: None | str
|
64
|
-
self._protected_node_types: tuple[str] = ()
|
65
|
-
self._prune_node_types: tuple[str] = ()
|
66
|
-
self._splitter: None | Splitter
|
67
|
-
self._llm: None | BaseLanguageModel = None
|
68
|
-
self._max_tokens: None | int = max_tokens
|
69
|
-
|
70
|
-
self.set_source_language(source_language)
|
71
|
-
self.set_protected_node_types(protected_node_types)
|
72
|
-
self.set_prune_node_types(prune_node_types)
|
73
|
-
|
74
|
-
# Child class must call this. Should we enforce somehow?
|
75
|
-
# self._load_parameters()
|
76
|
-
|
77
|
-
def __setattr__(self, key: Any, value: Any) -> None:
|
78
|
-
if hasattr(self, "_changed_attrs"):
|
79
|
-
if not hasattr(self, key) or getattr(self, key) != value:
|
80
|
-
self._changed_attrs.add(key)
|
81
|
-
# Avoid infinite recursion
|
82
|
-
elif key != "_changed_attrs":
|
83
|
-
self._changed_attrs = set()
|
84
|
-
super().__setattr__(key, value)
|
85
|
-
|
86
|
-
def _load_parameters(self) -> None:
|
87
|
-
self._load_splitter()
|
88
|
-
self._changed_attrs.clear()
|
89
|
-
|
90
|
-
def set_source_language(self, source_language: str) -> None:
|
91
|
-
"""Validate and set the source language.
|
92
|
-
|
93
|
-
The affected objects will not be updated until _load_parameters() is called.
|
94
|
-
|
95
|
-
Arguments:
|
96
|
-
source_language: The source programming language.
|
97
|
-
"""
|
98
|
-
source_language = source_language.lower()
|
99
|
-
if source_language not in LANGUAGES:
|
100
|
-
raise ValueError(
|
101
|
-
f"Invalid source language: {source_language}. "
|
102
|
-
"Valid source languages are found in `janus.utils.enums.LANGUAGES`."
|
103
|
-
)
|
104
|
-
|
105
|
-
self._source_glob = f"**/*.{LANGUAGES[source_language]['suffix']}"
|
106
|
-
self._source_language = source_language
|
107
|
-
|
108
|
-
def set_protected_node_types(
|
109
|
-
self, protected_node_types: set[str] | list[str] | tuple[str]
|
110
|
-
) -> None:
|
111
|
-
"""Set the protected (non-mergeable) node types. This will often be structures
|
112
|
-
like functions, classes, or modules which you might want to keep separate
|
113
|
-
|
114
|
-
The affected objects will not be updated until _load_parameters() is called.
|
115
|
-
|
116
|
-
Arguments:
|
117
|
-
protected_node_types: A set of node types that aren't to be merged
|
118
|
-
"""
|
119
|
-
self._protected_node_types = tuple(set(protected_node_types or []))
|
120
|
-
|
121
|
-
def set_prune_node_types(
|
122
|
-
self, prune_node_types: set[str] | list[str] | tuple[str]
|
123
|
-
) -> None:
|
124
|
-
"""Set the node types to prune. This will often be structures
|
125
|
-
like comments or whitespace which you might want to keep out of the LLM
|
126
|
-
|
127
|
-
The affected objects will not be updated until _load_parameters() is called.
|
128
|
-
|
129
|
-
Arguments:
|
130
|
-
prune_node_types: A set of node types which should be pruned
|
131
|
-
"""
|
132
|
-
self._prune_node_types = tuple(set(prune_node_types or []))
|
133
|
-
|
134
|
-
@run_if_changed(
|
135
|
-
"_source_language",
|
136
|
-
"_max_tokens",
|
137
|
-
"_llm",
|
138
|
-
"_protected_node_types",
|
139
|
-
"_prune_node_types",
|
140
|
-
)
|
141
|
-
def _load_splitter(self) -> None:
|
142
|
-
"""Load the splitter according to this instance's attributes.
|
143
|
-
|
144
|
-
If the relevant fields have not been changed since the last time this method was
|
145
|
-
called, nothing happens.
|
146
|
-
"""
|
147
|
-
kwargs = dict(
|
148
|
-
max_tokens=self._max_tokens,
|
149
|
-
model=self._llm,
|
150
|
-
protected_node_types=self._protected_node_types,
|
151
|
-
prune_node_types=self._prune_node_types,
|
152
|
-
)
|
153
|
-
if self._source_language in CUSTOM_SPLITTERS:
|
154
|
-
if self._source_language == "mumps":
|
155
|
-
self._splitter = MumpsSplitter(**kwargs)
|
156
|
-
elif self._source_language == "ibmhlasm":
|
157
|
-
self._splitter = AlcSplitter(**kwargs)
|
158
|
-
elif self._source_language == "binary":
|
159
|
-
self._splitter = BinarySplitter(**kwargs)
|
160
|
-
else:
|
161
|
-
self._splitter = TreeSitterSplitter(language=self._source_language, **kwargs)
|