janus-llm 4.1.0__py3-none-any.whl → 4.3.1__py3-none-any.whl
- janus/__init__.py +1 -1
- janus/cli.py +286 -30
- janus/converter/__init__.py +1 -0
- janus/converter/converter.py +46 -47
- janus/converter/evaluate.py +230 -4
- janus/converter/partition.py +27 -0
- janus/language/alc/_tests/test_alc.py +1 -1
- janus/language/alc/alc.py +9 -4
- janus/language/combine.py +22 -0
- janus/language/splitter.py +31 -23
- janus/language/treesitter/treesitter.py +9 -1
- janus/llm/models_info.py +20 -12
- janus/parsers/eval_parsers/incose_parser.py +134 -0
- janus/parsers/eval_parsers/inline_comment_parser.py +112 -0
- janus/parsers/partition_parser.py +168 -0
- janus/refiners/refiner.py +38 -12
- janus/refiners/uml.py +33 -0
- janus/retrievers/retriever.py +60 -0
- janus/utils/enums.py +14 -0
- janus/utils/pdf_docs_reader.py +134 -0
- {janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/METADATA +9 -1
- {janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/RECORD +25 -19
- {janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/WHEEL +1 -1
- {janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/LICENSE +0 -0
- {janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/entry_points.txt +0 -0
janus/parsers/eval_parsers/inline_comment_parser.py
ADDED
@@ -0,0 +1,112 @@
+import json
+import re
+from typing import Any
+
+from langchain.output_parsers import PydanticOutputParser
+from langchain_core.exceptions import OutputParserException
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import BaseModel, Field, conint
+
+from janus.language.block import CodeBlock
+from janus.parsers.parser import JanusParser
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class Criteria(BaseModel):
+    reasoning: str = Field(description="A short explanation for the given score")
+    # Constrained to an integer between 1 and 4
+    score: conint(ge=1, le=4) = Field(  # type: ignore
+        description="An integer score between 1 and 4 (inclusive), 4 being the best"
+    )
+
+
+class Comment(BaseModel):
+    comment_id: str = Field(description="The 8-character comment ID")
+    completeness: Criteria = Field(description="The completeness of the comment")
+    hallucination: Criteria = Field(description="The factualness of the comment")
+    readability: Criteria = Field(description="The readability of the comment")
+    usefulness: Criteria = Field(description="The usefulness of the comment")
+
+
+class CommentList(BaseModel):
+    __root__: list[Comment] = Field(
+        description=(
+            "A list of inline comment evaluations. Each element should include"
+            " the comment's 8-character ID in the `comment_id` field, and four"
+            " score objects corresponding to each metric (`completeness`,"
+            " `hallucination`, `readability`, and `usefulness`)."
+        )
+    )
+
+
+class InlineCommentParser(JanusParser, PydanticOutputParser):
+    comments: dict[str, str]
+
+    def __init__(self):
+        PydanticOutputParser.__init__(
+            self,
+            pydantic_object=CommentList,
+            comments=[],
+        )
+
+    def parse_input(self, block: CodeBlock) -> str:
+        # TODO: Perform comment stripping/placeholding here rather than in script
+        text = super().parse_input(block)
+        self.comments = dict(
+            re.findall(
+                r"<(?:BLOCK|INLINE)_COMMENT (\w{8})> (.*)$",
+                text,
+                flags=re.MULTILINE,
+            )
+        )
+        return text
+
+    def parse(self, text: str | BaseMessage) -> str:
+        if isinstance(text, BaseMessage):
+            text = str(text.content)
+
+        # Strip everything outside the JSON object
+        begin, end = text.find("["), text.rfind("]")
+        text = text[begin : end + 1]
+
+        try:
+            out: CommentList = super().parse(text)
+        except json.JSONDecodeError as e:
+            log.debug(f"Invalid JSON object. Output:\n{text}")
+            raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+
+        evals: dict[str, Any] = {c.comment_id: c.dict() for c in out.__root__}
+
+        seen_keys = set(evals.keys())
+        expected_keys = set(self.comments.keys())
+        missing_keys = expected_keys.difference(seen_keys)
+        invalid_keys = seen_keys.difference(expected_keys)
+        if missing_keys:
+            log.debug(f"Missing keys: {missing_keys}")
+            if invalid_keys:
+                log.debug(f"Invalid keys: {invalid_keys}")
+            raise OutputParserException(
+                f"Got invalid return object. Missing the following expected "
+                f"keys: {missing_keys}"
+            )
+
+        for key in invalid_keys:
+            del evals[key]
+
+        for cid in evals.keys():
+            evals[cid]["comment"] = self.comments[cid]
+            evals[cid].pop("comment_id")
+
+        return json.dumps(evals)
+
+    def parse_combined_output(self, text: str) -> str:
+        if not text.strip():
+            return str({})
+        objs = [json.loads(line.strip()) for line in text.split("\n") if line.strip()]
+        output_obj = {}
+        for obj in objs:
+            output_obj.update(obj)
+        return json.dumps(output_obj)
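For orientation, a minimal usage sketch of the new InlineCommentParser (illustrative only, not part of the diff; assumes `code_block` is a CodeBlock whose text already contains <INLINE_COMMENT xxxxxxxx> placeholders and `completion` is the model's reply):

    from janus.parsers.eval_parsers.inline_comment_parser import InlineCommentParser

    parser = InlineCommentParser()
    prompt_text = parser.parse_input(code_block)  # records the {comment_id: comment} map
    # ... send prompt_text to the model, collect its reply as `completion` ...
    scores_json = parser.parse(completion)  # JSON of per-comment scores, keyed by comment ID
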
janus/parsers/partition_parser.py
ADDED
@@ -0,0 +1,168 @@
+import json
+import random
+import uuid
+
+from langchain.output_parsers import PydanticOutputParser
+from langchain_core.exceptions import OutputParserException
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from janus.language.block import CodeBlock
+from janus.parsers.parser import JanusParser
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+RNG = random.Random()
+
+
+class PartitionObject(BaseModel):
+    reasoning: str = Field(
+        description="An explanation for why the code should be split at this point"
+    )
+    location: str = Field(
+        description="The 8-character line label which should start a new chunk"
+    )
+
+
+class PartitionList(BaseModel):
+    __root__: list[PartitionObject] = Field(
+        description=(
+            "A list of appropriate split points, each with a `reasoning` field "
+            "that explains a justification for splitting the code at that point, "
+            "and a `location` field which is simply the 8-character line ID. "
+            "The `reasoning` field should always be included first."
+        )
+    )
+
+
+# The following IDs appear in the prompt example. If the LLM produces them,
+# they should be ignored
+EXAMPLE_IDS = {
+    "0d2f4f8d",
+    "def2a953",
+    "75315253",
+    "e7f928da",
+    "1781b2a9",
+    "2fe21e27",
+    "9aef6179",
+    "6061bd82",
+    "22bd0c30",
+    "5d85e19e",
+    "06027969",
+    "91b722fb",
+    "4b3f79be",
+    "k57w964a",
+    "51638s96",
+    "065o6q32",
+    "j5q6p852",
+}
+
+
+class PartitionParser(JanusParser, PydanticOutputParser):
+    token_limit: int
+    model: BaseLanguageModel
+    lines: list[str] = []
+    line_id_to_index: dict[str, int] = {}
+
+    def __init__(self, token_limit: int, model: BaseLanguageModel):
+        PydanticOutputParser.__init__(
+            self,
+            pydantic_object=PartitionList,
+            model=model,
+            token_limit=token_limit,
+        )
+
+    def parse_input(self, block: CodeBlock) -> str:
+        code = str(block.text)
+        RNG.seed(code)
+
+        self.lines = code.split("\n")
+
+        # Generate a unique ID for each line (ensure they are unique)
+        line_ids = set()
+        while len(line_ids) < len(self.lines):
+            line_id = str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8]
+            if line_id in EXAMPLE_IDS:
+                continue
+            line_ids.add(line_id)
+
+        # Prepend each line with the corresponding ID, save the mapping
+        self.line_id_to_index = {lid: i for i, lid in enumerate(line_ids)}
+        processed = "\n".join(
+            f"{line_id}\t{self.lines[i]}" for line_id, i in self.line_id_to_index.items()
+        )
+        return processed
+
+    def parse(self, text: str | BaseMessage) -> str:
+        if isinstance(text, BaseMessage):
+            text = str(text.content)
+
+        # Strip everything outside the JSON object
+        begin, end = text.find("["), text.rfind("]")
+        text = text[begin : end + 1]
+
+        try:
+            out: PartitionList = super().parse(text)
+        except (OutputParserException, json.JSONDecodeError):
+            log.debug(f"Invalid JSON object. Output:\n{text}")
+            raise
+
+        # Get partition locations, discard reasoning
+        partition_locations = {partition.location for partition in out.__root__}
+
+        # Ignore IDs from the example input
+        partition_locations.difference_update(EXAMPLE_IDS)
+
+        # Locate any invalid line IDs, raise exception if any found
+        invalid_splits = partition_locations.difference(self.line_id_to_index)
+        if invalid_splits:
+            err_msg = (
+                f"{len(invalid_splits)} line ID(s) not found in input: "
+                + ", ".join(invalid_splits)
+            )
+            log.warning(err_msg)
+            raise OutputParserException(err_msg)
+
+        # Map line IDs to indices (so they can be sorted and lines indexed)
+        index_to_line_id = {0: "START", None: "END"}
+        split_points = {0}
+        for partition in partition_locations:
+            index = self.line_id_to_index[partition]
+            index_to_line_id[index] = partition
+            split_points.add(index)
+
+        # Get partition start/ends, chunks, chunk lengths
+        split_points = sorted(split_points) + [None]
+        partition_indices = list(zip(split_points, split_points[1:]))
+        partition_points = [
+            (index_to_line_id[i0], index_to_line_id[i1]) for i0, i1 in partition_indices
+        ]
+        chunks = ["\n".join(self.lines[i0:i1]) for i0, i1 in partition_indices]
+        chunk_tokens = list(map(self.model.get_num_tokens, chunks))
+
+        # Collect any chunks that exceed token limit
+        oversized_indices: list[int] = [
+            i for i, n in enumerate(chunk_tokens) if n > self.token_limit
+        ]
+        if oversized_indices:
+            data = list(zip(partition_points, chunks, chunk_tokens))
+            data = [data[i] for i in oversized_indices]
+
+            problem_points = "\n".join(
+                [
+                    f"{i0} to {i1} ({t / self.token_limit:.1f}x maximum length)"
+                    for (i0, i1), _, t in data
+                ]
+            )
+            log.warning(f"Found {len(data)} oversized chunks:\n{problem_points}")
+            log.debug(
+                "Oversized chunks:\n"
+                + "\n#############\n".join(chunk for _, chunk, _ in data)
+            )
+            raise OutputParserException(
+                f"The following segments are too long and must be "
+                f"further subdivided:\n{problem_points}"
+            )
+
+        return "\n<JANUS_PARTITION>\n".join(chunks)
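A sketch of how the new PartitionParser is driven (illustrative; assumes `llm` is a JanusModel, `block` is a CodeBlock, and `model_output` is the model's JSON reply):

    from janus.parsers.partition_parser import PartitionParser

    parser = PartitionParser(token_limit=4096, model=llm)
    labeled = parser.parse_input(block)  # each line becomes "<8-char ID>\t<line text>"
    # ... the model returns a JSON list of {reasoning, location} split points ...
    chunks = parser.parse(model_output).split("\n<JANUS_PARTITION>\n")
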
janus/refiners/refiner.py
CHANGED
@@ -1,6 +1,8 @@
+import re
 from typing import Any
 
 from langchain.output_parsers import RetryWithErrorOutputParser
+from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompt_values import PromptValue
 from langchain_core.runnables import RunnableSerializable
@@ -25,9 +27,38 @@ class JanusRefiner(JanusParser):
         raise NotImplementedError
 
 
+class SimpleRetry(JanusRefiner):
+    max_retries: int
+    retry_chain: RunnableSerializable
+
+    def __init__(
+        self,
+        llm: JanusModel,
+        parser: JanusParser,
+        max_retries: int,
+    ):
+        retry_chain = llm | StrOutputParser()
+        super().__init__(
+            retry_chain=retry_chain,
+            parser=parser,
+            max_retries=max_retries,
+        )
+
+    def parse_completion(
+        self, completion: str, prompt_value: PromptValue, **kwargs
+    ) -> Any:
+        for retry_number in range(self.max_retries):
+            try:
+                return self.parser.parse(completion)
+            except OutputParserException:
+                completion = self.retry_chain.invoke(prompt_value)
+
+        return self.parser.parse(completion)
+
+
 class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
     def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
-        retry_prompt = MODEL_PROMPT_ENGINES[llm.
+        retry_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
             source_language="text",
             prompt_template="refinement/fix_exceptions",
         ).prompt
@@ -46,6 +77,7 @@ class ReflectionRefiner(JanusRefiner):
     max_retries: int
     reflection_chain: RunnableSerializable
     revision_chain: RunnableSerializable
+    reflection_prompt_name: str
 
     def __init__(
         self,
@@ -54,11 +86,11 @@ class ReflectionRefiner(JanusRefiner):
         max_retries: int,
         prompt_template_name: str = "refinement/reflection",
     ):
-        reflection_prompt = MODEL_PROMPT_ENGINES[llm.
+        reflection_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
             source_language="text",
             prompt_template=prompt_template_name,
         ).prompt
-        revision_prompt = MODEL_PROMPT_ENGINES[llm.
+        revision_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
             source_language="text",
             prompt_template="refinement/revision",
         ).prompt
@@ -66,6 +98,7 @@ class ReflectionRefiner(JanusRefiner):
         reflection_chain = reflection_prompt | llm | StrOutputParser()
         revision_chain = revision_prompt | llm | StrOutputParser()
         super().__init__(
+            reflection_prompt_name=prompt_template_name,
             reflection_chain=reflection_chain,
             revision_chain=revision_chain,
             parser=parser,
@@ -75,6 +108,7 @@ class ReflectionRefiner(JanusRefiner):
     def parse_completion(
         self, completion: str, prompt_value: PromptValue, **kwargs
     ) -> Any:
+        log.info(f"Reflection Prompt: {self.reflection_prompt_name}")
         for retry_number in range(self.max_retries):
             reflection = self.reflection_chain.invoke(
                 dict(
@@ -82,7 +116,7 @@ class ReflectionRefiner(JanusRefiner):
                     completion=completion,
                 )
             )
-            if
+            if re.search(r"\bLGTM\b", reflection) is not None:
                 return self.parser.parse(completion)
             if not retry_number:
                 log.info(f"Completion:\n{completion}")
@@ -105,11 +139,3 @@ class HallucinationRefiner(ReflectionRefiner):
             prompt_template_name="refinement/hallucination",
             **kwargs,
         )
-
-
-REFINERS = dict(
-    none=JanusRefiner,
-    parser=FixParserExceptions,
-    reflection=ReflectionRefiner,
-    hallucination=HallucinationRefiner,
-)
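A sketch of the new SimpleRetry refiner in use (illustrative; assumes `llm`, `parser`, `completion`, and `prompt_value` already exist). Unlike FixParserExceptions, it simply re-invokes the model on the original prompt without feeding the parse error back:

    from janus.refiners.refiner import SimpleRetry

    refiner = SimpleRetry(llm=llm, parser=parser, max_retries=3)
    result = refiner.parse_completion(completion=completion, prompt_value=prompt_value)
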
janus/refiners/uml.py
ADDED
@@ -0,0 +1,33 @@
+from janus.llm.models_info import JanusModel
+from janus.parsers.parser import JanusParser
+from janus.refiners.refiner import ReflectionRefiner
+
+
+class ALCFixUMLVariablesRefiner(ReflectionRefiner):
+    def __init__(
+        self,
+        llm: JanusModel,
+        parser: JanusParser,
+        max_retries: int,
+    ):
+        super().__init__(
+            llm=llm,
+            parser=parser,
+            max_retries=max_retries,
+            prompt_template_name="refinement/uml/alc_fix_variables",
+        )
+
+
+class FixUMLConnectionsRefiner(ReflectionRefiner):
+    def __init__(
+        self,
+        llm: JanusModel,
+        parser: JanusParser,
+        max_retries: int,
+    ):
+        super().__init__(
+            llm=llm,
+            parser=parser,
+            max_retries=max_retries,
+            prompt_template_name="refinement/uml/fix_connections",
+        )
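Both classes follow the same pattern: a thin ReflectionRefiner subclass that pins a specific prompt template. A hypothetical custom refiner would look identical with its own template path:

    from janus.refiners.refiner import ReflectionRefiner

    class MyUMLRefiner(ReflectionRefiner):  # hypothetical example, not in the package
        def __init__(self, llm, parser, max_retries):
            super().__init__(
                llm=llm,
                parser=parser,
                max_retries=max_retries,
                prompt_template_name="refinement/uml/my_template",  # hypothetical path
            )
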
janus/retrievers/retriever.py
CHANGED
@@ -1,7 +1,16 @@
+from typing import List
+
+from langchain_core.documents import Document
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.runnables import Runnable, RunnableConfig
 
 from janus.language.block import CodeBlock
+from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
+from janus.utils.logger import create_logger
+from janus.utils.pdf_docs_reader import PDFDocsReader
+
+log = create_logger(__name__)
 
 
 class JanusRetriever(Runnable):
@@ -40,3 +49,54 @@ class TextSearchRetriever(JanusRetriever):
         docs = self.retriever.invoke(code_block.text)
         context = "\n\n".join(doc.page_content for doc in docs)
         return f"You may use the following additional context: {context}"
+
+
+class LanguageDocsRetriever(JanusRetriever):
+    def __init__(
+        self,
+        llm: JanusModel,
+        language_name: str,
+        prompt_template_name: str = "retrieval/language_docs",
+    ):
+        super().__init__()
+        self.llm: JanusModel = llm
+        self.language: str = language_name
+
+        self.PDF_reader = PDFDocsReader(
+            language=self.language,
+        )
+
+        language_docs_prompt = MODEL_PROMPT_ENGINES[self.llm.short_model_id](
+            source_language=self.language,
+            prompt_template=prompt_template_name,
+        ).prompt
+
+        parser: StrOutputParser = StrOutputParser()
+        self.chain = language_docs_prompt | self.llm | parser
+
+    def get_context(self, code_block: CodeBlock) -> str:
+        functionality_to_reference: str = self.chain.invoke(
+            dict({"SOURCE_CODE": code_block.text, "SOURCE_LANGUAGE": self.language})
+        )
+        if functionality_to_reference == "NODOCS":
+            log.debug("No Opcodes requested from language docs retriever.")
+            return ""
+        else:
+            functionality_to_reference: List = functionality_to_reference.split(", ")
+            log.debug(
+                f"List of opcodes requested by language docs retriever "
+                f"to search the {self.language} "
+                f"docs for: {functionality_to_reference}"
+            )
+
+            docs: List[Document] = self.PDF_reader.search_language_reference(
+                functionality_to_reference
+            )
+            context = "\n\n".join(doc.page_content for doc in docs)
+            if context:
+                return (
+                    f"You may reference the following excerpts from the {self.language} "
+                    f"language documentation: {context}"
+                )
+            else:
+                return ""
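A usage sketch for the new LanguageDocsRetriever (illustrative; assumes `llm` is a JanusModel, `code_block` is a CodeBlock, and a matching PDF manual is available to PDFDocsReader):

    from janus.retrievers.retriever import LanguageDocsRetriever

    retriever = LanguageDocsRetriever(llm=llm, language_name="mumps")
    context = retriever.get_context(code_block)  # "" when the model answers NODOCS
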
janus/utils/enums.py
CHANGED
@@ -89,6 +89,20 @@ LANGUAGES: Dict[str, Dict[str, Any]] = {
         "url": "https://github.com/stsewd/tree-sitter-comment",
         "example": "# This is a comment\n",
     },
+    "cobol": {
+        "comment": "*",
+        "suffix": "cbl",
+        "url": "https://github.com/yutaro-sakamoto/tree-sitter-cobol",
+        "example": (
+            "       IDENTIFICATION DIVISION.\n"
+            "       PROGRAM-ID. HelloWorld.\n"
+            "       ENVIRONMENT DIVISION.\n"
+            "       DATA DIVISION.\n"
+            "       PROCEDURE DIVISION.\n"
+            '           DISPLAY "Hello, World!".\n'
+            "           STOP RUN.\n"
+        ),
+    },
     "commonlisp": {
         "comment": ";;",
         "suffix": "lisp",
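The new entry is consumed like any other LANGUAGES record, e.g. (illustrative):

    from janus.utils.enums import LANGUAGES

    cobol = LANGUAGES["cobol"]
    cobol["suffix"]   # "cbl"
    cobol["comment"]  # "*"
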
janus/utils/pdf_docs_reader.py
ADDED
@@ -0,0 +1,134 @@
+import os
+import time
+from pathlib import Path
+from typing import List, Optional
+
+import joblib
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_core.documents import Document
+from langchain_unstructured import UnstructuredLoader
+from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class PDFDocsReader:
+    def __init__(
+        self,
+        language: str,
+        chunk_size: int = 1000,
+        chunk_overlap: int = 100,
+        start_page: Optional[int] = None,
+        end_page: Optional[int] = None,
+        vectorizer: CountVectorizer = TfidfVectorizer(),
+    ):
+        self.retrieval_docs_dir: Path = Path(
+            os.getenv("RETRIEVAL_DOCS_DIR", "retrieval_docs")
+        )
+        self.language = language
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.start_page = start_page
+        self.end_page = end_page
+        self.vectorizer = vectorizer
+        self.documents = self.load_and_chunk_pdf()
+        self.doc_vectors = self.vectorize_documents()
+
+    def load_and_chunk_pdf(self) -> List[str]:
+        pdf_path = self.retrieval_docs_dir / f"{self.language}.pdf"
+        pickled_documents_path = (
+            self.retrieval_docs_dir / f"{self.language}_documents.pkl"
+        )
+
+        if pickled_documents_path.exists():
+            log.debug(
+                f"Loading pre-chunked PDF from {pickled_documents_path}. "
+                f"If you want to regenerate retrieval docs for {self.language}, "
+                f"delete the file at {pickled_documents_path}, "
+                f"then add a new {self.language}.pdf."
+            )
+            documents = joblib.load(pickled_documents_path)
+        else:
+            if not pdf_path.exists():
+                raise FileNotFoundError(
+                    f"Language docs retrieval is enabled, but no PDF for language "
+                    f"'{self.language}' was found. Move a "
+                    f"{self.language} reference manual to "
+                    f"{pdf_path.absolute()} "
+                    f"(the path to the directory of PDF docs can be "
+                    f"set with the env variable 'RETRIEVAL_DOCS_DIR')."
+                )
+            log.info(
+                f"Chunking reference PDF for {self.language} using unstructured - "
+                f"if your PDF has many pages, this could take a while..."
+            )
+            start_time = time.time()
+            loader = UnstructuredLoader(
+                pdf_path,
+                chunking_strategy="basic",
+                max_characters=1000000,
+                include_orig_elements=False,
+                start_page=self.start_page,
+                end_page=self.end_page,
+            )
+            docs = loader.load()
+            text = "\n\n".join([doc.page_content for doc in docs])
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+            )
+            documents = text_splitter.split_text(text)
+            log.info(f"Document store created for language: {self.language}")
+            end_time = time.time()
+            log.info(
+                f"Processing time for {self.language} PDF: "
+                f"{end_time - start_time} seconds"
+            )
+
+            joblib.dump(documents, pickled_documents_path)
+            log.debug(f"Documents saved to {pickled_documents_path}")
+
+        return documents
+
+    def vectorize_documents(self):
+        doc_vectors = self.vectorizer.fit_transform(self.documents)
+        return doc_vectors
+
+    def search_language_reference(
+        self,
+        query: List[str],
+        top_k: int = 1,
+        min_similarity: float = 0.1,
+    ) -> List[Document]:
+        """Searches through the vectorized PDF for the query using
+        tf-idf and returns a list of langchain Documents."""
+
+        docs: List[Document] = []
+
+        for item in query:
+            # Transform the query using the TF-IDF vectorizer
+            query_vector = self.vectorizer.transform([item])
+
+            # Calculate cosine similarities between the query and document vectors
+            similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()
+
+            # Get the indices of documents with similarity above the threshold
+            valid_indices = [
+                i for i, sim in enumerate(similarities) if sim >= min_similarity
+            ]
+
+            # Sort the valid indices by similarity score in descending order
+            sorted_indices = sorted(
+                valid_indices, key=lambda i: similarities[i], reverse=True
+            )
+
+            # Limit to top-k results
+            top_indices = sorted_indices[:top_k]
+
+            # Retrieve the top-k most relevant documents
+            docs += [Document(page_content=self.documents[i]) for i in top_indices]
+            log.debug(f"Language documentation search result: {docs}")
+
+        return docs
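A usage sketch for PDFDocsReader on its own (illustrative; assumes retrieval_docs/mumps.pdf exists, or a cached mumps_documents.pkl from an earlier run; the query strings are arbitrary examples):

    from janus.utils.pdf_docs_reader import PDFDocsReader

    reader = PDFDocsReader(language="mumps")
    docs = reader.search_language_reference(["$ORDER", "$PIECE"], top_k=2)
    context = "\n\n".join(doc.page_content for doc in docs)
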
{janus_llm-4.1.0.dist-info → janus_llm-4.3.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: janus-llm
-Version: 4.1.0
+Version: 4.3.1
 Summary: A transcoding library using LLMs.
 Home-page: https://github.com/janus-llm/janus-llm
 License: Apache 2.0
@@ -23,20 +23,28 @@ Requires-Dist: langchain-anthropic (>=0.1.15,<0.2.0)
 Requires-Dist: langchain-community (>=0.2.0,<0.3.0)
 Requires-Dist: langchain-core (>=0.2.0,<0.3.0)
 Requires-Dist: langchain-openai (>=0.1.8,<0.2.0)
+Requires-Dist: langchain-unstructured (>=0.1.2,<0.2.0)
 Requires-Dist: nltk (>=3.8.1,<4.0.0)
 Requires-Dist: numpy (>=1.24.3,<2.0.0)
 Requires-Dist: openai (>=1.14.0,<2.0.0)
+Requires-Dist: pi-heif (>=0.20.0,<0.21.0)
 Requires-Dist: py-readability-metrics (>=1.4.5,<2.0.0)
 Requires-Dist: py-rouge (>=1.1,<2.0)
+Requires-Dist: pytesseract (>=0.3.13,<0.4.0)
 Requires-Dist: python-dotenv (>=1.0.0,<2.0.0)
 Requires-Dist: rich (>=13.7.1,<14.0.0)
 Requires-Dist: sacrebleu (>=2.4.1,<3.0.0)
+Requires-Dist: scikit-learn (>=1.5.2,<2.0.0)
 Requires-Dist: sentence-transformers (>=2.6.1,<3.0.0) ; extra == "hf-local" or extra == "all"
+Requires-Dist: tesseract (>=0.1.3,<0.2.0)
 Requires-Dist: text-generation (>=0.6.0,<0.7.0)
 Requires-Dist: tiktoken (>=0.7.0,<0.8.0)
 Requires-Dist: transformers (>=4.31.0,<5.0.0)
 Requires-Dist: tree-sitter (>=0.21.0,<0.22.0)
 Requires-Dist: typer (>=0.9.0,<0.10.0)
+Requires-Dist: unstructured (>=0.15.9,<0.16.0)
+Requires-Dist: unstructured-inference (>=0.7.36,<0.8.0)
+Requires-Dist: unstructured-pytesseract (>=0.3.13,<0.4.0)
 Project-URL: Documentation, https://janus-llm.github.io/janus-llm
 Project-URL: Repository, https://github.com/janus-llm/janus-llm
 Description-Content-Type: text/markdown