janus-llm 4.1.0__py3-none-any.whl → 4.3.1__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- janus/__init__.py +1 -1
- janus/cli.py +286 -30
- janus/converter/__init__.py +1 -0
- janus/converter/converter.py +46 -47
- janus/converter/evaluate.py +230 -4
- janus/converter/partition.py +27 -0
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +9 -4
- janus/language/combine.py +22 -0
- janus/language/splitter.py +31 -23
- janus/language/treesitter/treesitter.py +9 -1
- janus/llm/models_info.py +20 -12
- janus/parsers/eval_parsers/incose_parser.py +134 -0
- janus/parsers/eval_parsers/inline_comment_parser.py +112 -0
- janus/parsers/partition_parser.py +168 -0
- janus/refiners/refiner.py +38 -12
- janus/refiners/uml.py +33 -0
- janus/retrievers/retriever.py +60 -0
- janus/utils/enums.py +14 -0
- janus/utils/pdf_docs_reader.py +134 -0
- {janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/METADATA +9 -1
- {janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/RECORD +25 -19
- {janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/WHEEL +1 -1
- {janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/LICENSE +0 -0
- {janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/entry_points.txt +0 -0
janus/parsers/eval_parsers/inline_comment_parser.py
ADDED
@@ -0,0 +1,112 @@
+import json
+import re
+from typing import Any
+
+from langchain.output_parsers import PydanticOutputParser
+from langchain_core.exceptions import OutputParserException
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import BaseModel, Field, conint
+
+from janus.language.block import CodeBlock
+from janus.parsers.parser import JanusParser
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class Criteria(BaseModel):
+    reasoning: str = Field(description="A short explanation for the given score")
+    # Constrained to an integer between 1 and 4
+    score: conint(ge=1, le=4) = Field(  # type: ignore
+        description="An integer score between 1 and 4 (inclusive), 4 being the best"
+    )
+
+
+class Comment(BaseModel):
+    comment_id: str = Field(description="The 8-character comment ID")
+    completeness: Criteria = Field(description="The completeness of the comment")
+    hallucination: Criteria = Field(description="The factualness of the comment")
+    readability: Criteria = Field(description="The readability of the comment")
+    usefulness: Criteria = Field(description="The usefulness of the comment")
+
+
+class CommentList(BaseModel):
+    __root__: list[Comment] = Field(
+        description=(
+            "A list of inline comment evaluations. Each element should include"
+            " the comment's 8-character ID in the `comment_id` field, and four"
+            " score objects corresponding to each metric (`completeness`,"
+            " `hallucination`, `readability`, and `usefulness`)."
+        )
+    )
+
+
+class InlineCommentParser(JanusParser, PydanticOutputParser):
+    comments: dict[str, str]
+
+    def __init__(self):
+        PydanticOutputParser.__init__(
+            self,
+            pydantic_object=CommentList,
+            comments=[],
+        )
+
+    def parse_input(self, block: CodeBlock) -> str:
+        # TODO: Perform comment stripping/placeholding here rather than in script
+        text = super().parse_input(block)
+        self.comments = dict(
+            re.findall(
+                r"<(?:BLOCK|INLINE)_COMMENT (\w{8})> (.*)$",
+                text,
+                flags=re.MULTILINE,
+            )
+        )
+        return text
+
+    def parse(self, text: str | BaseMessage) -> str:
+        if isinstance(text, BaseMessage):
+            text = str(text.content)
+
+        # Strip everything outside the JSON object
+        begin, end = text.find("["), text.rfind("]")
+        text = text[begin : end + 1]
+
+        try:
+            out: CommentList = super().parse(text)
+        except json.JSONDecodeError as e:
+            log.debug(f"Invalid JSON object. Output:\n{text}")
+            raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+
+        evals: dict[str, Any] = {c.comment_id: c.dict() for c in out.__root__}
+
+        seen_keys = set(evals.keys())
+        expected_keys = set(self.comments.keys())
+        missing_keys = expected_keys.difference(seen_keys)
+        invalid_keys = seen_keys.difference(expected_keys)
+        if missing_keys:
+            log.debug(f"Missing keys: {missing_keys}")
+            if invalid_keys:
+                log.debug(f"Invalid keys: {invalid_keys}")
+            log.debug(f"Missing keys: {missing_keys}")
+            raise OutputParserException(
+                f"Got invalid return object. Missing the following expected "
+                f"keys: {missing_keys}"
+            )
+
+        for key in invalid_keys:
+            del evals[key]
+
+        for cid in evals.keys():
+            evals[cid]["comment"] = self.comments[cid]
+            evals[cid].pop("comment_id")
+
+        return json.dumps(evals)
+
+    def parse_combined_output(self, text: str) -> str:
+        if not text.strip():
+            return str({})
+        objs = [json.loads(line.strip()) for line in text.split("\n") if line.strip()]
+        output_obj = {}
+        for obj in objs:
+            output_obj.update(obj)
+        return json.dumps(output_obj)
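The `parse_input` regex above is what pairs each 8-character comment ID with its text. A minimal standalone sketch of that extraction (the tagged snippet is hypothetical, but follows the `<BLOCK_COMMENT ...>`/`<INLINE_COMMENT ...>` format the parser expects):

import re

# Hypothetical input in the tagged format InlineCommentParser.parse_input
# expects: each comment has been replaced by a tag carrying an 8-character ID.
text = (
    "<BLOCK_COMMENT 1a2b3c4d> Initialize the accumulator\n"
    "MOVE ZERO TO ACC  <INLINE_COMMENT 9f8e7d6c> reset before loop\n"
)

# The same pattern the parser uses to build its comment_id -> comment map
comments = dict(
    re.findall(r"<(?:BLOCK|INLINE)_COMMENT (\w{8})> (.*)$", text, flags=re.MULTILINE)
)
print(comments)
# {'1a2b3c4d': 'Initialize the accumulator', '9f8e7d6c': 'reset before loop'}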
janus/parsers/partition_parser.py
ADDED
@@ -0,0 +1,168 @@
+import json
+import random
+import uuid
+
+from langchain.output_parsers import PydanticOutputParser
+from langchain_core.exceptions import OutputParserException
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from janus.language.block import CodeBlock
+from janus.parsers.parser import JanusParser
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+RNG = random.Random()
+
+
+class PartitionObject(BaseModel):
+    reasoning: str = Field(
+        description="An explanation for why the code should be split at this point"
+    )
+    location: str = Field(
+        description="The 8-character line label which should start a new chunk"
+    )
+
+
+class PartitionList(BaseModel):
+    __root__: list[PartitionObject] = Field(
+        description=(
+            "A list of appropriate split points, each with a `reasoning` field "
+            "that explains a justification for splitting the code at that point, "
+            "and a `location` field which is simply the 8-character line ID. "
+            "The `reasoning` field should always be included first."
+        )
+    )
+
+
+# The following IDs appear in the prompt example. If the LLM produces them,
+# they should be ignored
+EXAMPLE_IDS = {
+    "0d2f4f8d",
+    "def2a953",
+    "75315253",
+    "e7f928da",
+    "1781b2a9",
+    "2fe21e27",
+    "9aef6179",
+    "6061bd82",
+    "22bd0c30",
+    "5d85e19e",
+    "06027969",
+    "91b722fb",
+    "4b3f79be",
+    "k57w964a",
+    "51638s96",
+    "065o6q32",
+    "j5q6p852",
+}
+
+
+class PartitionParser(JanusParser, PydanticOutputParser):
+    token_limit: int
+    model: BaseLanguageModel
+    lines: list[str] = []
+    line_id_to_index: dict[str, int] = {}
+
+    def __init__(self, token_limit: int, model: BaseLanguageModel):
+        PydanticOutputParser.__init__(
+            self,
+            pydantic_object=PartitionList,
+            model=model,
+            token_limit=token_limit,
+        )
+
+    def parse_input(self, block: CodeBlock) -> str:
+        code = str(block.text)
+        RNG.seed(code)
+
+        self.lines = code.split("\n")
+
+        # Generate a unique ID for each line (ensure they are unique)
+        line_ids = set()
+        while len(line_ids) < len(self.lines):
+            line_id = str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8]
+            if line_id in EXAMPLE_IDS:
+                continue
+            line_ids.add(line_id)
+
+        # Prepend each line with the corresponding ID, save the mapping
+        self.line_id_to_index = {lid: i for i, lid in enumerate(line_ids)}
+        processed = "\n".join(
+            f"{line_id}\t{self.lines[i]}" for line_id, i in self.line_id_to_index.items()
+        )
+        return processed
+
+    def parse(self, text: str | BaseMessage) -> str:
+        if isinstance(text, BaseMessage):
+            text = str(text.content)
+
+        # Strip everything outside the JSON object
+        begin, end = text.find("["), text.rfind("]")
+        text = text[begin : end + 1]
+
+        try:
+            out: PartitionList = super().parse(text)
+        except (OutputParserException, json.JSONDecodeError):
+            log.debug(f"Invalid JSON object. Output:\n{text}")
+            raise
+
+        # Get partition locations, discard reasoning
+        partition_locations = {partition.location for partition in out.__root__}
+
+        # Ignore IDs from the example input
+        partition_locations.difference_update(EXAMPLE_IDS)
+
+        # Locate any invalid line IDs, raise exception if any found
+        invalid_splits = partition_locations.difference(self.line_id_to_index)
+        if invalid_splits:
+            err_msg = (
+                f"{len(invalid_splits)} line ID(s) not found in input: "
+                + ", ".join(invalid_splits)
+            )
+            log.warning(err_msg)
+            raise OutputParserException(err_msg)
+
+        # Map line IDs to indices (so they can be sorted and lines indexed)
+        index_to_line_id = {0: "START", None: "END"}
+        split_points = {0}
+        for partition in partition_locations:
+            index = self.line_id_to_index[partition]
+            index_to_line_id[index] = partition
+            split_points.add(index)
+
+        # Get partition start/ends, chunks, chunk lengths
+        split_points = sorted(split_points) + [None]
+        partition_indices = list(zip(split_points, split_points[1:]))
+        partition_points = [
+            (index_to_line_id[i0], index_to_line_id[i1]) for i0, i1 in partition_indices
+        ]
+        chunks = ["\n".join(self.lines[i0:i1]) for i0, i1 in partition_indices]
+        chunk_tokens = list(map(self.model.get_num_tokens, chunks))
+
+        # Collect any chunks that exceed token limit
+        oversized_indices: list[int] = [
+            i for i, n in enumerate(chunk_tokens) if n > self.token_limit
+        ]
+        if oversized_indices:
+            data = list(zip(partition_points, chunks, chunk_tokens))
+            data = [data[i] for i in oversized_indices]
+
+            problem_points = "\n".join(
+                [
+                    f"{i0} to {i1} ({t / self.token_limit:.1f}x maximum length)"
+                    for (i0, i1), _, t in data
+                ]
+            )
+            log.warning(f"Found {len(data)} oversized chunks:\n{problem_points}")
+            log.debug(
+                "Oversized chunks:\n"
+                + "\n#############\n".join(chunk for _, chunk, _ in data)
+            )
+            raise OutputParserException(
                f"The following segments are too long and must be "
+                f"further subdivided:\n{problem_points}"
+            )
+
+        return "\n<JANUS_PARTITION>\n".join(chunks)
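The chunking arithmetic in `parse` above is easiest to see in isolation. A minimal sketch using plain integer indices in place of the 8-character line IDs (the sample lines are hypothetical):

# Suppose the model proposed splits at lines 2 and 4; index 0 is always a
# split point, and a trailing None makes the final slice run to the end.
lines = ["line0", "line1", "line2", "line3", "line4"]
split_points = sorted({0, 2, 4}) + [None]               # [0, 2, 4, None]
partition_indices = list(zip(split_points, split_points[1:]))

# Each (start, end) pair slices the original lines, exactly as in parse()
chunks = ["\n".join(lines[i0:i1]) for i0, i1 in partition_indices]
print(chunks)
# ['line0\nline1', 'line2\nline3', 'line4']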
janus/refiners/refiner.py
CHANGED
@@ -1,6 +1,8 @@
+import re
 from typing import Any
 
 from langchain.output_parsers import RetryWithErrorOutputParser
+from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompt_values import PromptValue
 from langchain_core.runnables import RunnableSerializable
@@ -25,9 +27,38 @@ class JanusRefiner(JanusParser):
         raise NotImplementedError
 
 
+class SimpleRetry(JanusRefiner):
+    max_retries: int
+    retry_chain: RunnableSerializable
+
+    def __init__(
+        self,
+        llm: JanusModel,
+        parser: JanusParser,
+        max_retries: int,
+    ):
+        retry_chain = llm | StrOutputParser()
+        super().__init__(
+            retry_chain=retry_chain,
+            parser=parser,
+            max_retries=max_retries,
+        )
+
+    def parse_completion(
+        self, completion: str, prompt_value: PromptValue, **kwargs
+    ) -> Any:
+        for retry_number in range(self.max_retries):
+            try:
+                return self.parser.parse(completion)
+            except OutputParserException:
+                completion = self.retry_chain.invoke(prompt_value)
+
+        return self.parser.parse(completion)
+
+
 class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
     def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
-        retry_prompt = MODEL_PROMPT_ENGINES[llm.
+        retry_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
             source_language="text",
             prompt_template="refinement/fix_exceptions",
         ).prompt
@@ -46,6 +77,7 @@ class ReflectionRefiner(JanusRefiner):
     max_retries: int
     reflection_chain: RunnableSerializable
     revision_chain: RunnableSerializable
+    reflection_prompt_name: str
 
     def __init__(
         self,
@@ -54,11 +86,11 @@ class ReflectionRefiner(JanusRefiner):
         max_retries: int,
         prompt_template_name: str = "refinement/reflection",
     ):
-        reflection_prompt = MODEL_PROMPT_ENGINES[llm.
+        reflection_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
             source_language="text",
             prompt_template=prompt_template_name,
         ).prompt
-        revision_prompt = MODEL_PROMPT_ENGINES[llm.
+        revision_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
             source_language="text",
             prompt_template="refinement/revision",
         ).prompt
@@ -66,6 +98,7 @@ class ReflectionRefiner(JanusRefiner):
         reflection_chain = reflection_prompt | llm | StrOutputParser()
         revision_chain = revision_prompt | llm | StrOutputParser()
         super().__init__(
+            reflection_prompt_name=prompt_template_name,
             reflection_chain=reflection_chain,
             revision_chain=revision_chain,
             parser=parser,
@@ -75,6 +108,7 @@ class ReflectionRefiner(JanusRefiner):
     def parse_completion(
         self, completion: str, prompt_value: PromptValue, **kwargs
     ) -> Any:
+        log.info(f"Reflection Prompt: {self.reflection_prompt_name}")
         for retry_number in range(self.max_retries):
             reflection = self.reflection_chain.invoke(
                 dict(
@@ -82,7 +116,7 @@ class ReflectionRefiner(JanusRefiner):
                     completion=completion,
                 )
             )
-            if
+            if re.search(r"\bLGTM\b", reflection) is not None:
                 return self.parser.parse(completion)
             if not retry_number:
                 log.info(f"Completion:\n{completion}")
@@ -105,11 +139,3 @@ class HallucinationRefiner(ReflectionRefiner):
             prompt_template_name="refinement/hallucination",
             **kwargs,
         )
-
-
-REFINERS = dict(
-    none=JanusRefiner,
-    parser=FixParserExceptions,
-    reflection=ReflectionRefiner,
-    hallucination=HallucinationRefiner,
-)
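The new reflection check accepts a completion only when `LGTM` appears as a whole word, so the refiner is not fooled by the token occurring inside a longer word, and a lowercase `lgtm` does not count. A quick illustration:

import re

for reflection in ("LGTM", "Looks good. LGTM!", "FLGTMX", "lgtm"):
    approved = re.search(r"\bLGTM\b", reflection) is not None
    print(f"{reflection!r}: {approved}")
# 'LGTM': True
# 'Looks good. LGTM!': True
# 'FLGTMX': False
# 'lgtm': False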
janus/refiners/uml.py
ADDED
@@ -0,0 +1,33 @@
+from janus.llm.models_info import JanusModel
+from janus.parsers.parser import JanusParser
+from janus.refiners.refiner import ReflectionRefiner
+
+
+class ALCFixUMLVariablesRefiner(ReflectionRefiner):
+    def __init__(
+        self,
+        llm: JanusModel,
+        parser: JanusParser,
+        max_retries: int,
+    ):
+        super().__init__(
+            llm=llm,
+            parser=parser,
+            max_retries=max_retries,
+            prompt_template_name="refinement/uml/alc_fix_variables",
+        )
+
+
+class FixUMLConnectionsRefiner(ReflectionRefiner):
+    def __init__(
+        self,
+        llm: JanusModel,
+        parser: JanusParser,
+        max_retries: int,
+    ):
+        super().__init__(
+            llm=llm,
+            parser=parser,
+            max_retries=max_retries,
+            prompt_template_name="refinement/uml/fix_connections",
+        )
janus/retrievers/retriever.py
CHANGED
@@ -1,7 +1,16 @@
+from typing import List
+
+from langchain_core.documents import Document
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.runnables import Runnable, RunnableConfig
 
 from janus.language.block import CodeBlock
+from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
+from janus.utils.logger import create_logger
+from janus.utils.pdf_docs_reader import PDFDocsReader
+
+log = create_logger(__name__)
 
 
 class JanusRetriever(Runnable):
@@ -40,3 +49,54 @@ class TextSearchRetriever(JanusRetriever):
         docs = self.retriever.invoke(code_block.text)
         context = "\n\n".join(doc.page_content for doc in docs)
         return f"You may use the following additional context: {context}"
+
+
+class LanguageDocsRetriever(JanusRetriever):
+    def __init__(
+        self,
+        llm: JanusModel,
+        language_name: str,
+        prompt_template_name: str = "retrieval/language_docs",
+    ):
+        super().__init__()
+        self.llm: JanusModel = llm
+        self.language: str = language_name
+
+        self.PDF_reader = PDFDocsReader(
+            language=self.language,
+        )
+
+        language_docs_prompt = MODEL_PROMPT_ENGINES[self.llm.short_model_id](
+            source_language=self.language,
+            prompt_template=prompt_template_name,
+        ).prompt
+
+        parser: StrOutputParser = StrOutputParser()
+        self.chain = language_docs_prompt | self.llm | parser
+
+    def get_context(self, code_block: CodeBlock) -> str:
+        functionality_to_reference: str = self.chain.invoke(
+            dict({"SOURCE_CODE": code_block.text, "SOURCE_LANGUAGE": self.language})
+        )
+        if functionality_to_reference == "NODOCS":
+            log.debug("No Opcodes requested from language docs retriever.")
+            return ""
+        else:
+            functionality_to_reference: List = functionality_to_reference.split(", ")
+            log.debug(
+                f"List of opcodes requested by language docs retriever"
+                f"to search the {self.language} "
+                f"docs for: {functionality_to_reference}"
+            )
+
+            docs: List[Document] = self.PDF_reader.search_language_reference(
+                functionality_to_reference
+            )
+            context = "\n\n".join(doc.page_content for doc in docs)
+            if context:
+                return (
+                    f"You may reference the following excerpts from the {self.language} "
+                    f"language documentation: {context}"
+                )
+            else:
+                return ""
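`get_context` above relies on a simple output contract: the prompt chain returns either the sentinel string `NODOCS` or a comma-separated list of opcodes to look up. A minimal sketch of that interpretation (the opcode names are hypothetical):

# Sketch of the contract between the prompt chain and get_context
def interpret_chain_output(functionality_to_reference: str) -> list[str]:
    if functionality_to_reference == "NODOCS":
        return []  # nothing to search the language docs for
    return functionality_to_reference.split(", ")

print(interpret_chain_output("NODOCS"))          # []
print(interpret_chain_output("MVC, PACK, ZAP"))  # ['MVC', 'PACK', 'ZAP']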
janus/utils/enums.py
CHANGED
@@ -89,6 +89,20 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
         "url": "https://github.com/stsewd/tree-sitter-comment",
         "example": "# This is a comment\n",
     },
+    "cobol": {
+        "comment": "*",
+        "suffix": "cbl",
+        "url": "https://github.com/yutaro-sakamoto/tree-sitter-cobol",
+        "example": (
+            " IDENTIFICATION DIVISION.\n"
+            " PROGRAM-ID. HelloWorld.\n"
+            " ENVIRONMENT DIVISION.\n"
+            " DATA DIVISION.\n"
+            " PROCEDURE DIVISION.\n"
+            ' DISPLAY "Hello, World!".\n'
+            " STOP RUN.\n"
+        ),
+    },
     "commonlisp": {
         "comment": ";;",
         "suffix": "lisp",
janus/utils/pdf_docs_reader.py
ADDED
@@ -0,0 +1,134 @@
+import os
+import time
+from pathlib import Path
+from typing import List, Optional
+
+import joblib
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_core.documents import Document
+from langchain_unstructured import UnstructuredLoader
+from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class PDFDocsReader:
+    def __init__(
+        self,
+        language: str,
+        chunk_size: int = 1000,
+        chunk_overlap: int = 100,
+        start_page: Optional[int] = None,
+        end_page: Optional[int] = None,
+        vectorizer: CountVectorizer = TfidfVectorizer(),
+    ):
+        self.retrieval_docs_dir: Path = Path(
+            os.getenv("RETRIEVAL_DOCS_DIR", "retrieval_docs")
+        )
+        self.language = language
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.start_page = start_page
+        self.end_page = end_page
+        self.vectorizer = vectorizer
+        self.documents = self.load_and_chunk_pdf()
+        self.doc_vectors = self.vectorize_documents()
+
+    def load_and_chunk_pdf(self) -> List[str]:
+        pdf_path = self.retrieval_docs_dir / f"{self.language}.pdf"
+        pickled_documents_path = (
+            self.retrieval_docs_dir / f"{self.language}_documents.pkl"
+        )
+
+        if pickled_documents_path.exists():
+            log.debug(
+                f"Loading pre-chunked PDF from {pickled_documents_path}. "
+                f"If you want to regenerate retrieval docs for {self.language}, "
+                f"delete the file at {pickled_documents_path}, "
+                f"then add a new {self.language}.pdf."
+            )
+            documents = joblib.load(pickled_documents_path)
+        else:
+            if not pdf_path.exists():
+                raise FileNotFoundError(
+                    f"Language docs retrieval is enabled, but no PDF for language "
+                    f"'{self.language}' was found. Move a "
+                    f"{self.language} reference manual to "
+                    f"{pdf_path.absolute()} "
+                    f"(the path to the directory of PDF docs can be "
+                    f"set with the env variable 'RETRIEVAL_DOCS_DIR')."
+                )
+            log.info(
+                f"Chunking reference PDF for {self.language} using unstructured - "
+                f"if your PDF has many pages, this could take a while..."
+            )
+            start_time = time.time()
+            loader = UnstructuredLoader(
+                pdf_path,
+                chunking_strategy="basic",
+                max_characters=1000000,
+                include_orig_elements=False,
+                start_page=self.start_page,
+                end_page=self.end_page,
+            )
+            docs = loader.load()
+            text = "\n\n".join([doc.page_content for doc in docs])
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+            )
+            documents = text_splitter.split_text(text)
+            log.info(f"Document store created for language: {self.language}")
+            end_time = time.time()
+            log.info(
+                f"Processing time for {self.language} PDF: "
+                f"{end_time - start_time} seconds"
+            )
+
+            joblib.dump(documents, pickled_documents_path)
+            log.debug(f"Documents saved to {pickled_documents_path}")
+
+        return documents
+
+    def vectorize_documents(self) -> (TfidfVectorizer, any):
+        doc_vectors = self.vectorizer.fit_transform(self.documents)
+        return doc_vectors
+
+    def search_language_reference(
+        self,
+        query: List[str],
+        top_k: int = 1,
+        min_similarity: float = 0.1,
+    ) -> List[Document]:
+        """Searches through the vectorized PDF for the query using
+        tf-idf and returns a list of langchain Documents."""
+
+        docs: List[Document] = []
+
+        for item in query:
+            # Transform the query using the TF-IDF vectorizer
+            query_vector = self.vectorizer.transform([item])
+
+            # Calculate cosine similarities between the query and document vectors
+            similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()
+
+            # Get the indices of documents with similarity above the threshold
+            valid_indices = [
+                i for i, sim in enumerate(similarities) if sim >= min_similarity
+            ]
+
+            # Sort the valid indices by similarity score in descending order
+            sorted_indices = sorted(
+                valid_indices, key=lambda i: similarities[i], reverse=True
+            )
+
+            # Limit to top-k results
+            top_indices = sorted_indices[:top_k]
+
+            # Retrieve the top-k most relevant documents
+            docs += [Document(page_content=self.documents[i]) for i in top_indices]
+            log.debug(f"Langauge documentation search result: {docs}")
+
+        return docs
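`search_language_reference` above is plain TF-IDF retrieval over the chunked PDF. A self-contained sketch of the same scoring, with hypothetical stand-ins for the chunk text:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Toy stand-ins for the chunks produced by load_and_chunk_pdf()
documents = [
    "MVC moves characters from the second operand to the first operand.",
    "PACK converts zoned decimal data to packed decimal format.",
    "The DISPLAY statement writes data to the output device.",
]

vectorizer = TfidfVectorizer()
doc_vectors = vectorizer.fit_transform(documents)

# Same scoring as search_language_reference: vectorize the query, rank
# chunks by cosine similarity, keep the best match above the threshold.
query_vector = vectorizer.transform(["PACK packed decimal"])
similarities = cosine_similarity(query_vector, doc_vectors).flatten()
best = similarities.argmax()
print(documents[best])  # the PACK chunk scores highest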
{janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: janus-llm
-Version: 4.1.0
+Version: 4.3.1
 Summary: A transcoding library using LLMs.
 Home-page: https://github.com/janus-llm/janus-llm
 License: Apache 2.0
@@ -23,20 +23,28 @@ Requires-Dist: langchain-anthropic (>=0.1.15,<0.2.0)
 Requires-Dist: langchain-community (>=0.2.0,<0.3.0)
 Requires-Dist: langchain-core (>=0.2.0,<0.3.0)
 Requires-Dist: langchain-openai (>=0.1.8,<0.2.0)
+Requires-Dist: langchain-unstructured (>=0.1.2,<0.2.0)
 Requires-Dist: nltk (>=3.8.1,<4.0.0)
 Requires-Dist: numpy (>=1.24.3,<2.0.0)
 Requires-Dist: openai (>=1.14.0,<2.0.0)
+Requires-Dist: pi-heif (>=0.20.0,<0.21.0)
 Requires-Dist: py-readability-metrics (>=1.4.5,<2.0.0)
 Requires-Dist: py-rouge (>=1.1,<2.0)
+Requires-Dist: pytesseract (>=0.3.13,<0.4.0)
 Requires-Dist: python-dotenv (>=1.0.0,<2.0.0)
 Requires-Dist: rich (>=13.7.1,<14.0.0)
 Requires-Dist: sacrebleu (>=2.4.1,<3.0.0)
+Requires-Dist: scikit-learn (>=1.5.2,<2.0.0)
 Requires-Dist: sentence-transformers (>=2.6.1,<3.0.0) ; extra == "hf-local" or extra == "all"
+Requires-Dist: tesseract (>=0.1.3,<0.2.0)
 Requires-Dist: text-generation (>=0.6.0,<0.7.0)
 Requires-Dist: tiktoken (>=0.7.0,<0.8.0)
 Requires-Dist: transformers (>=4.31.0,<5.0.0)
 Requires-Dist: tree-sitter (>=0.21.0,<0.22.0)
 Requires-Dist: typer (>=0.9.0,<0.10.0)
+Requires-Dist: unstructured (>=0.15.9,<0.16.0)
+Requires-Dist: unstructured-inference (>=0.7.36,<0.8.0)
+Requires-Dist: unstructured-pytesseract (>=0.3.13,<0.4.0)
 Project-URL: Documentation, https://janus-llm.github.io/janus-llm
 Project-URL: Repository, https://github.com/janus-llm/janus-llm
 Description-Content-Type: text/markdown