janus-llm 2.1.0__py3-none-any.whl → 3.0.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- janus/__init__.py +2 -2
- janus/__main__.py +1 -1
- janus/_tests/test_cli.py +1 -2
- janus/cli.py +43 -50
- janus/converter/__init__.py +6 -0
- janus/converter/_tests/__init__.py +0 -0
- janus/{_tests → converter/_tests}/test_translate.py +11 -22
- janus/converter/converter.py +614 -0
- janus/converter/diagram.py +124 -0
- janus/converter/document.py +131 -0
- janus/converter/evaluate.py +15 -0
- janus/converter/requirements.py +51 -0
- janus/converter/translate.py +108 -0
- janus/language/block.py +1 -1
- janus/language/combine.py +0 -1
- janus/language/treesitter/treesitter.py +20 -1
- janus/llm/model_callbacks.py +33 -36
- janus/llm/models_info.py +14 -0
- janus/metrics/reading.py +27 -5
- janus/prompts/prompt.py +37 -11
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/METADATA +1 -1
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/RECORD +25 -19
- janus/converter.py +0 -161
- janus/translate.py +0 -987
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/LICENSE +0 -0
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/WHEEL +0 -0
- {janus_llm-2.1.0.dist-info → janus_llm-3.0.1.dist-info}/entry_points.txt +0 -0
janus/metrics/reading.py
CHANGED
@@ -1,9 +1,30 @@
|
|
1
|
+
import re
|
2
|
+
|
1
3
|
import nltk
|
2
4
|
import readability
|
5
|
+
from nltk.tokenize import TweetTokenizer
|
3
6
|
|
4
7
|
from .metric import metric
|
5
8
|
|
6
9
|
|
10
|
+
def word_count(text):
|
11
|
+
"""Calculates word count exactly how readability package does
|
12
|
+
|
13
|
+
Arguments:
|
14
|
+
text: The input string.
|
15
|
+
|
16
|
+
Returns:
|
17
|
+
Word Count
|
18
|
+
"""
|
19
|
+
tokenizer = TweetTokenizer()
|
20
|
+
word_count = 0
|
21
|
+
tokens = tokenizer.tokenize(text)
|
22
|
+
for t in tokens:
|
23
|
+
if not re.match(r"^[.,\/#!$%'\^&\*;:{}=\-_`~()]+$", t):
|
24
|
+
word_count += 1
|
25
|
+
return word_count
|
26
|
+
|
27
|
+
|
7
28
|
def _repeat_text(text):
|
8
29
|
"""Repeats a string until its length is over 100 words.
|
9
30
|
|
@@ -20,11 +41,10 @@ def _repeat_text(text):
|
|
20
41
|
if not text.endswith("."):
|
21
42
|
text += "." # Add a period if missing
|
22
43
|
|
23
|
-
# Check if repeated text is long enough, repeat more if needed
|
24
44
|
repeated_text = text
|
25
|
-
while len(repeated_text.split()) < 100:
|
26
|
-
repeated_text += " " + text
|
27
45
|
|
46
|
+
while word_count(repeated_text) < 100:
|
47
|
+
repeated_text += " " + text
|
28
48
|
return repeated_text
|
29
49
|
|
30
50
|
|
@@ -52,7 +72,8 @@ def flesch(target: str, **kwargs) -> float:
|
|
52
72
|
Returns:
|
53
73
|
The Flesch score.
|
54
74
|
"""
|
55
|
-
|
75
|
+
if not target.strip(): # Check if the target text is blank
|
76
|
+
return None
|
56
77
|
return get_readability(target).flesch().score
|
57
78
|
|
58
79
|
|
@@ -66,5 +87,6 @@ def gunning_fog(target: str, **kwargs) -> float:
|
|
66
87
|
Returns:
|
67
88
|
The Gunning-Fog score.
|
68
89
|
"""
|
69
|
-
|
90
|
+
if not target.strip(): # Check if the target text is blank
|
91
|
+
return None
|
70
92
|
return get_readability(target).gunning_fog().score
|
janus/prompts/prompt.py
CHANGED
@@ -2,12 +2,12 @@ import json
|
|
2
2
|
from abc import ABC, abstractmethod
|
3
3
|
from pathlib import Path
|
4
4
|
|
5
|
-
from langchain import PromptTemplate
|
6
5
|
from langchain.prompts import ChatPromptTemplate
|
7
6
|
from langchain.prompts.chat import (
|
8
7
|
HumanMessagePromptTemplate,
|
9
8
|
SystemMessagePromptTemplate,
|
10
9
|
)
|
10
|
+
from langchain_core.prompts import PromptTemplate
|
11
11
|
|
12
12
|
from ..utils.enums import LANGUAGES
|
13
13
|
from ..utils.logger import create_logger
|
@@ -40,7 +40,7 @@ retry_with_output_prompt_text = """Instructions:
|
|
40
40
|
--------------
|
41
41
|
Completion:
|
42
42
|
--------------
|
43
|
-
{
|
43
|
+
{completion}
|
44
44
|
--------------
|
45
45
|
|
46
46
|
Above, the Completion did not satisfy the constraints given in the Instructions.
|
@@ -54,13 +54,23 @@ constraints laid out in the Instructions:"""
|
|
54
54
|
|
55
55
|
|
56
56
|
retry_with_error_and_output_prompt_text = """Prompt:
|
57
|
+
--------------
|
57
58
|
{prompt}
|
59
|
+
--------------
|
58
60
|
Completion:
|
59
|
-
|
61
|
+
--------------
|
62
|
+
{completion}
|
63
|
+
--------------
|
60
64
|
|
61
65
|
Above, the Completion did not satisfy the constraints given in the Prompt.
|
62
|
-
|
63
|
-
|
66
|
+
Error:
|
67
|
+
--------------
|
68
|
+
{error}
|
69
|
+
--------------
|
70
|
+
|
71
|
+
Please try again. Please only respond with an answer that satisfies the
|
72
|
+
constraints laid out in the Prompt:"""
|
73
|
+
|
64
74
|
|
65
75
|
retry_with_output_prompt = PromptTemplate.from_template(retry_with_output_prompt_text)
|
66
76
|
retry_with_error_and_output_prompt = PromptTemplate.from_template(
|
@@ -74,9 +84,9 @@ class PromptEngine(ABC):
|
|
74
84
|
def __init__(
|
75
85
|
self,
|
76
86
|
source_language: str,
|
77
|
-
target_language: str,
|
78
|
-
target_version: str,
|
79
87
|
prompt_template: str,
|
88
|
+
target_language: str | None = None,
|
89
|
+
target_version: str | None = None,
|
80
90
|
) -> None:
|
81
91
|
"""Initialize a PromptEngine instance.
|
82
92
|
|
@@ -97,15 +107,18 @@ class PromptEngine(ABC):
|
|
97
107
|
|
98
108
|
# Define variables to be passed in to the prompt formatter
|
99
109
|
source_language = source_language.lower()
|
100
|
-
target_language = target_language.lower()
|
101
110
|
self.variables = dict(
|
102
111
|
SOURCE_LANGUAGE=source_language,
|
103
|
-
TARGET_LANGUAGE=target_language,
|
104
|
-
TARGET_LANGUAGE_VERSION=str(target_version),
|
105
112
|
FILE_SUFFIX=LANGUAGES[source_language]["suffix"],
|
106
113
|
SOURCE_CODE_EXAMPLE=LANGUAGES[source_language]["example"],
|
107
|
-
TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
|
108
114
|
)
|
115
|
+
if target_language is not None:
|
116
|
+
target_language = target_language.lower()
|
117
|
+
self.variables.update(
|
118
|
+
TARGET_LANGUAGE=target_language,
|
119
|
+
TARGET_CODE_EXAMPLE=LANGUAGES[target_language]["example"],
|
120
|
+
)
|
121
|
+
self.variables.update(TARGET_LANGUAGE_VERSION=str(target_version))
|
109
122
|
variables_path = template_path / PROMPT_VARIABLES_FILENAME
|
110
123
|
if variables_path.exists():
|
111
124
|
self.variables.update(json.loads(variables_path.read_text()))
|
@@ -253,3 +266,16 @@ class CoherePromptEngine(PromptEngine):
|
|
253
266
|
f"{human_prompt}"
|
254
267
|
f"<|END_OF_TURN_TOKEN|>"
|
255
268
|
)
|
269
|
+
|
270
|
+
|
271
|
+
class MistralPromptEngine(PromptEngine):
|
272
|
+
def load_prompt_template(self, template_path: Path) -> ChatPromptTemplate:
|
273
|
+
system_prompt_path = template_path / SYSTEM_PROMPT_TEMPLATE_FILENAME
|
274
|
+
system_prompt = system_prompt_path.read_text()
|
275
|
+
|
276
|
+
human_prompt_path = template_path / HUMAN_PROMPT_TEMPLATE_FILENAME
|
277
|
+
human_prompt = human_prompt_path.read_text()
|
278
|
+
|
279
|
+
return PromptTemplate.from_template(
|
280
|
+
f"<s>[INST] {system_prompt} [/INST] </s>[INST] {human_prompt} [/INST]"
|
281
|
+
)
|
@@ -1,11 +1,18 @@
|
|
1
|
-
janus/__init__.py,sha256=
|
2
|
-
janus/__main__.py,sha256=
|
1
|
+
janus/__init__.py,sha256=OkY9msOgEgU_jai5YdGjfv6hQG6UopI7S2J_M9EKs-U,351
|
2
|
+
janus/__main__.py,sha256=lEkpNtLVPtFo8ySDZeXJ_NXDHb0GVdZFPWB4gD4RPS8,64
|
3
3
|
janus/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
janus/_tests/conftest.py,sha256=V7uW-oq3YbFiRPvrq15YoVVrA1n_83pjgiyTZ-IUGW8,963
|
5
|
-
janus/_tests/test_cli.py,sha256=
|
6
|
-
janus/
|
7
|
-
janus/
|
8
|
-
janus/converter.py,sha256=
|
5
|
+
janus/_tests/test_cli.py,sha256=mi7wAWV07ZFli5nQdExRGIGA3AMFD9s39-HcmDV4B6Y,4232
|
6
|
+
janus/cli.py,sha256=-aeg8R6CobK2EG_BPoZgBy_x1d6G9gp-KKKhnLMepo4,29541
|
7
|
+
janus/converter/__init__.py,sha256=kzVmWOPXRDayqqBZ8ZDaFQzA_q8PEdv407dc-DefPxY,255
|
8
|
+
janus/converter/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
+
janus/converter/_tests/test_translate.py,sha256=eiLbmouokZrAeAYmdoJgnlx5-k4QiO6i0N6e6ZvZsvM,15885
|
10
|
+
janus/converter/converter.py,sha256=Bq07_9N_3Dv9NBqVACvb7LC2HxdQmfVZ1b0BlWrxjgo,23521
|
11
|
+
janus/converter/diagram.py,sha256=v-3ZZ4t1q74lDOjF2N6NRPkC3IK-sjLDn5_VChZTEGA,4608
|
12
|
+
janus/converter/document.py,sha256=hsW512veNjFWbdl5WriuUdNmMEqZy8ktRvqn9rRmA6E,4566
|
13
|
+
janus/converter/evaluate.py,sha256=APWQUY3gjAXqkJkPzvj0UA4wPK3Cv9QSJLM-YK9t-ng,476
|
14
|
+
janus/converter/requirements.py,sha256=6YvrJRVH9BuPCOPxnXmaJQFYmoLYYvCu3zTntDLHeNg,1832
|
15
|
+
janus/converter/translate.py,sha256=kMlGUiBYGQBXSxwX5in3CUyUifPM95wynCaRMxSDxMw,4238
|
9
16
|
janus/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
10
17
|
janus/embedding/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
18
|
janus/embedding/_tests/test_collections.py,sha256=eT0cYv-qmPrHJRjDZqWPFTkqVzFDRoPrRKR__FPiz58,2651
|
@@ -28,8 +35,8 @@ janus/language/binary/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
|
|
28
35
|
janus/language/binary/_tests/test_binary.py,sha256=a-8RSfKA23UrJC9c1xPQK792XZCz8npCHI7isN2dAP8,1727
|
29
36
|
janus/language/binary/binary.py,sha256=CS1RAieN8klSsCeXQEFYKUWioatUX-sOPXKQr5S6NzE,6534
|
30
37
|
janus/language/binary/reveng/decompile_script.py,sha256=veW51oJzuO-4UD3Er062jXZ_FYtTFo9OCkl82Z2xr6A,2182
|
31
|
-
janus/language/block.py,sha256=
|
32
|
-
janus/language/combine.py,sha256=
|
38
|
+
janus/language/block.py,sha256=57hfOY-KSVMioKhkCvfDtovQt4h8lCg9cJbRF7ddV1s,9280
|
39
|
+
janus/language/combine.py,sha256=e7j8zQO_D3_LElaVCsGgtnzia7aFFK56m-mhArQBlR0,2908
|
33
40
|
janus/language/file.py,sha256=X2MYcAMlCABK77uhMdI_J2foXLrqEdinapYRfLPyKB8,563
|
34
41
|
janus/language/mumps/__init__.py,sha256=-Ou_wJ-JgHezfp1dub2_qCYNiK9wO-zo2MlqxM9qiwE,48
|
35
42
|
janus/language/mumps/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -47,10 +54,10 @@ janus/language/splitter.py,sha256=4XAe0hXka7njS30UHGCngJzDgHxn3lygUjikSHuV7Xo,16
|
|
47
54
|
janus/language/treesitter/__init__.py,sha256=mUliw7ZJLZ8NkJKyUQMSoUV82hYXE0HvLHrEdGPJF4Q,43
|
48
55
|
janus/language/treesitter/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
49
56
|
janus/language/treesitter/_tests/test_treesitter.py,sha256=nsavUV0aI6cpT9FkQve58eTTehLyQG6qJJBGlNa_bIw,2170
|
50
|
-
janus/language/treesitter/treesitter.py,sha256=
|
57
|
+
janus/language/treesitter/treesitter.py,sha256=UiV4OuWTt6IwMohHSw4FHsVNA_zxr9lNk4_Du09APdo,7509
|
51
58
|
janus/llm/__init__.py,sha256=8Pzn3Jdx867PzDc4xmwm8wvJDGzWSIhpN0NCEYFe0LQ,36
|
52
|
-
janus/llm/model_callbacks.py,sha256=
|
53
|
-
janus/llm/models_info.py,sha256=
|
59
|
+
janus/llm/model_callbacks.py,sha256=h_xlBAHRx-gxQBBjVKRpGXxdxYf6d9L6kBoXjbEAEdI,7106
|
60
|
+
janus/llm/models_info.py,sha256=B9Dn5mHc43OeZe5mHFj5wuhO194XHCTwShAa2ybnPyY,7688
|
54
61
|
janus/metrics/__init__.py,sha256=AsxtZJUzZiXJPr2ehPPltuYP-ddechjg6X85WZUO7mA,241
|
55
62
|
janus/metrics/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
56
63
|
janus/metrics/_tests/reference.py,sha256=hiaJPP9CXkvFBV_wL-gOe_BzELTw0nvB6uCxhxtIiE8,13
|
@@ -70,7 +77,7 @@ janus/metrics/complexity_metrics.py,sha256=1Z9n0o_CrILqayk40wRkjR1f7yvHIsJG38DxA
|
|
70
77
|
janus/metrics/file_pairing.py,sha256=WNHRV1D8GOJMq8Pla5SPkTDAT7yVaS4-UU0XIGKvEVs,3729
|
71
78
|
janus/metrics/llm_metrics.py,sha256=3677S6GYcoVcokpmAN-fwvNu-lYWAKd7M5mebiE6RZc,5687
|
72
79
|
janus/metrics/metric.py,sha256=Lgdtq87oJ-kWC_6jdPQ6-d1MqoeTnhkRszo6IZJV6c0,16974
|
73
|
-
janus/metrics/reading.py,sha256=
|
80
|
+
janus/metrics/reading.py,sha256=srLb2MO-vZL5ccRjaHz-dA4MwAvXVNyIKnOrvJXg77E,2244
|
74
81
|
janus/metrics/rouge_score.py,sha256=HfUJwUWI-yq5pOjML2ee4QTOMl0NQahnqEY2Mt8Dtnw,2865
|
75
82
|
janus/metrics/similarity.py,sha256=9pjWWpLKCsk0QfFfSgQNdPXiisqi7WJYOOHaiT8S0iY,1613
|
76
83
|
janus/metrics/splitting.py,sha256=610ScHRvALwdkqA6YyGI-tr3a18_cUofldBxGYX0SwE,968
|
@@ -82,8 +89,7 @@ janus/parsers/doc_parser.py,sha256=X8eCb1QXbL6sVWLEFGjsPyxrpJ9XnOPg7G4KZSo9A9E,5
|
|
82
89
|
janus/parsers/eval_parser.py,sha256=HB5-zY_Jpmkj6FDbuNCCVCRxwmzhViSAjPKbyyC0Ebc,2723
|
83
90
|
janus/parsers/reqs_parser.py,sha256=MFBvtR3otpyPZlkZxu0dVH1YeEJhvhNzhaGKGHaQVHA,2359
|
84
91
|
janus/prompts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
85
|
-
janus/prompts/prompt.py,sha256=
|
86
|
-
janus/translate.py,sha256=bIrvyFBXUH1Cf8M-h-qSybFe0NQwuCA38heiV2toP8w,38958
|
92
|
+
janus/prompts/prompt.py,sha256=vd7UbitF0VFCi21RsggDebD51xcuyls_lQLGKkphfI8,10578
|
87
93
|
janus/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
88
94
|
janus/utils/_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
89
95
|
janus/utils/_tests/test_logger.py,sha256=4jZFm8LX828Dt9lOjiFHZIPbxYy_hHaswyrMPkscgdM,2199
|
@@ -91,8 +97,8 @@ janus/utils/_tests/test_progress.py,sha256=Yh5NDNq-24n2nhHHbJm39pENAH70PYnh9ymwd
|
|
91
97
|
janus/utils/enums.py,sha256=AoilbdiYyMvY2Mp0AM4xlbLSELfut2XMwhIM1S_msP4,27610
|
92
98
|
janus/utils/logger.py,sha256=KZeuaMAnlSZCsj4yL0P6N-JzZwpxXygzACWfdZFeuek,2337
|
93
99
|
janus/utils/progress.py,sha256=pKcCzO9JOU9fSD7qTmLWcqY5smc8mujqQMXoPgqNysE,1458
|
94
|
-
janus_llm-
|
95
|
-
janus_llm-
|
96
|
-
janus_llm-
|
97
|
-
janus_llm-
|
98
|
-
janus_llm-
|
100
|
+
janus_llm-3.0.1.dist-info/LICENSE,sha256=_j0st0a-HB6MRbP3_BW3PUqpS16v54luyy-1zVyl8NU,10789
|
101
|
+
janus_llm-3.0.1.dist-info/METADATA,sha256=tVkX6eswouFez-ASfy-iWJk_iXptEeylGETEZF4MeeI,4184
|
102
|
+
janus_llm-3.0.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
103
|
+
janus_llm-3.0.1.dist-info/entry_points.txt,sha256=OGhQwzj6pvXp79B0SaBD5apGekCu7Dwe9fZZT_TZ544,39
|
104
|
+
janus_llm-3.0.1.dist-info/RECORD,,
|
janus/converter.py
DELETED
@@ -1,161 +0,0 @@
|
|
1
|
-
import functools
|
2
|
-
from typing import Any
|
3
|
-
|
4
|
-
from langchain.schema.language_model import BaseLanguageModel
|
5
|
-
|
6
|
-
from .language.alc.alc import AlcSplitter
|
7
|
-
from .language.binary import BinarySplitter
|
8
|
-
from .language.mumps import MumpsSplitter
|
9
|
-
from .language.splitter import Splitter
|
10
|
-
from .language.treesitter import TreeSitterSplitter
|
11
|
-
from .utils.enums import CUSTOM_SPLITTERS, LANGUAGES
|
12
|
-
from .utils.logger import create_logger
|
13
|
-
|
14
|
-
log = create_logger(__name__)
|
15
|
-
|
16
|
-
|
17
|
-
def run_if_changed(*tracked_vars):
|
18
|
-
"""Wrapper to skip function calls if the given instance attributes haven't
|
19
|
-
been updated. Requires the _changed_attrs set to exist, and the __setattr__
|
20
|
-
method to be overridden to track parameter updates in _changed_attrs.
|
21
|
-
"""
|
22
|
-
|
23
|
-
def wrapper(func):
|
24
|
-
@functools.wraps(func)
|
25
|
-
def wrapped(self, *args, **kwargs):
|
26
|
-
# If there is overlap between the tracked variables and the changed
|
27
|
-
# ones, then call the function as normal
|
28
|
-
if self._changed_attrs.intersection(tracked_vars):
|
29
|
-
func(self, *args, **kwargs)
|
30
|
-
|
31
|
-
return wrapped
|
32
|
-
|
33
|
-
return wrapper
|
34
|
-
|
35
|
-
|
36
|
-
class Converter:
|
37
|
-
"""Parent class that converts code into something else.
|
38
|
-
|
39
|
-
Children will determine what the code gets converted into. Whether that's translated
|
40
|
-
into another language, into pseudocode, requirements, documentation, etc., or
|
41
|
-
converted into embeddings
|
42
|
-
"""
|
43
|
-
|
44
|
-
def __init__(
|
45
|
-
self,
|
46
|
-
source_language: str = "fortran",
|
47
|
-
max_tokens: None | int = None,
|
48
|
-
protected_node_types: set[str] | list[str] | tuple[str] = (),
|
49
|
-
prune_node_types: set[str] | list[str] | tuple[str] = (),
|
50
|
-
) -> None:
|
51
|
-
"""Initialize a Converter instance.
|
52
|
-
|
53
|
-
Arguments:
|
54
|
-
source_language: The source programming language.
|
55
|
-
parser_type: The type of parser to use for parsing the LLM output. Valid
|
56
|
-
values are `"code"`, `"text"`, `"eval"`, and `None` (default). If `None`,
|
57
|
-
the `Converter` assumes you won't be parsing an output (i.e., adding to an
|
58
|
-
embedding DB).
|
59
|
-
"""
|
60
|
-
self._changed_attrs: set = set()
|
61
|
-
|
62
|
-
self._source_language: None | str
|
63
|
-
self._source_glob: None | str
|
64
|
-
self._protected_node_types: tuple[str] = ()
|
65
|
-
self._prune_node_types: tuple[str] = ()
|
66
|
-
self._splitter: None | Splitter
|
67
|
-
self._llm: None | BaseLanguageModel = None
|
68
|
-
self._max_tokens: None | int = max_tokens
|
69
|
-
|
70
|
-
self.set_source_language(source_language)
|
71
|
-
self.set_protected_node_types(protected_node_types)
|
72
|
-
self.set_prune_node_types(prune_node_types)
|
73
|
-
|
74
|
-
# Child class must call this. Should we enforce somehow?
|
75
|
-
# self._load_parameters()
|
76
|
-
|
77
|
-
def __setattr__(self, key: Any, value: Any) -> None:
|
78
|
-
if hasattr(self, "_changed_attrs"):
|
79
|
-
if not hasattr(self, key) or getattr(self, key) != value:
|
80
|
-
self._changed_attrs.add(key)
|
81
|
-
# Avoid infinite recursion
|
82
|
-
elif key != "_changed_attrs":
|
83
|
-
self._changed_attrs = set()
|
84
|
-
super().__setattr__(key, value)
|
85
|
-
|
86
|
-
def _load_parameters(self) -> None:
|
87
|
-
self._load_splitter()
|
88
|
-
self._changed_attrs.clear()
|
89
|
-
|
90
|
-
def set_source_language(self, source_language: str) -> None:
|
91
|
-
"""Validate and set the source language.
|
92
|
-
|
93
|
-
The affected objects will not be updated until _load_parameters() is called.
|
94
|
-
|
95
|
-
Arguments:
|
96
|
-
source_language: The source programming language.
|
97
|
-
"""
|
98
|
-
source_language = source_language.lower()
|
99
|
-
if source_language not in LANGUAGES:
|
100
|
-
raise ValueError(
|
101
|
-
f"Invalid source language: {source_language}. "
|
102
|
-
"Valid source languages are found in `janus.utils.enums.LANGUAGES`."
|
103
|
-
)
|
104
|
-
|
105
|
-
self._source_glob = f"**/*.{LANGUAGES[source_language]['suffix']}"
|
106
|
-
self._source_language = source_language
|
107
|
-
|
108
|
-
def set_protected_node_types(
|
109
|
-
self, protected_node_types: set[str] | list[str] | tuple[str]
|
110
|
-
) -> None:
|
111
|
-
"""Set the protected (non-mergeable) node types. This will often be structures
|
112
|
-
like functions, classes, or modules which you might want to keep separate
|
113
|
-
|
114
|
-
The affected objects will not be updated until _load_parameters() is called.
|
115
|
-
|
116
|
-
Arguments:
|
117
|
-
protected_node_types: A set of node types that aren't to be merged
|
118
|
-
"""
|
119
|
-
self._protected_node_types = tuple(set(protected_node_types or []))
|
120
|
-
|
121
|
-
def set_prune_node_types(
|
122
|
-
self, prune_node_types: set[str] | list[str] | tuple[str]
|
123
|
-
) -> None:
|
124
|
-
"""Set the node types to prune. This will often be structures
|
125
|
-
like comments or whitespace which you might want to keep out of the LLM
|
126
|
-
|
127
|
-
The affected objects will not be updated until _load_parameters() is called.
|
128
|
-
|
129
|
-
Arguments:
|
130
|
-
prune_node_types: A set of node types which should be pruned
|
131
|
-
"""
|
132
|
-
self._prune_node_types = tuple(set(prune_node_types or []))
|
133
|
-
|
134
|
-
@run_if_changed(
|
135
|
-
"_source_language",
|
136
|
-
"_max_tokens",
|
137
|
-
"_llm",
|
138
|
-
"_protected_node_types",
|
139
|
-
"_prune_node_types",
|
140
|
-
)
|
141
|
-
def _load_splitter(self) -> None:
|
142
|
-
"""Load the splitter according to this instance's attributes.
|
143
|
-
|
144
|
-
If the relevant fields have not been changed since the last time this method was
|
145
|
-
called, nothing happens.
|
146
|
-
"""
|
147
|
-
kwargs = dict(
|
148
|
-
max_tokens=self._max_tokens,
|
149
|
-
model=self._llm,
|
150
|
-
protected_node_types=self._protected_node_types,
|
151
|
-
prune_node_types=self._prune_node_types,
|
152
|
-
)
|
153
|
-
if self._source_language in CUSTOM_SPLITTERS:
|
154
|
-
if self._source_language == "mumps":
|
155
|
-
self._splitter = MumpsSplitter(**kwargs)
|
156
|
-
elif self._source_language == "ibmhlasm":
|
157
|
-
self._splitter = AlcSplitter(**kwargs)
|
158
|
-
elif self._source_language == "binary":
|
159
|
-
self._splitter = BinarySplitter(**kwargs)
|
160
|
-
else:
|
161
|
-
self._splitter = TreeSitterSplitter(language=self._source_language, **kwargs)
|