janus-llm 4.0.0__py3-none-any.whl → 4.2.0__py3-none-any.whl
- janus/__init__.py +1 -1
- janus/cli.py +161 -26
- janus/converter/__init__.py +1 -0
- janus/converter/_tests/test_translate.py +2 -2
- janus/converter/converter.py +45 -47
- janus/converter/partition.py +27 -0
- janus/language/combine.py +22 -0
- janus/llm/model_callbacks.py +9 -0
- janus/llm/models_info.py +41 -17
- janus/parsers/partition_parser.py +136 -0
- janus/refiners/refiner.py +8 -12
- janus/refiners/uml.py +33 -0
- janus/retrievers/retriever.py +60 -0
- janus/utils/pdf_docs_reader.py +134 -0
- {janus_llm-4.0.0.dist-info → janus_llm-4.2.0.dist-info}/METADATA +9 -1
- {janus_llm-4.0.0.dist-info → janus_llm-4.2.0.dist-info}/RECORD +19 -15
- {janus_llm-4.0.0.dist-info → janus_llm-4.2.0.dist-info}/WHEEL +1 -1
- {janus_llm-4.0.0.dist-info → janus_llm-4.2.0.dist-info}/LICENSE +0 -0
- {janus_llm-4.0.0.dist-info → janus_llm-4.2.0.dist-info}/entry_points.txt +0 -0
janus/llm/models_info.py CHANGED
@@ -1,15 +1,14 @@
 import json
 import os
-import time
 from pathlib import Path
-from typing import Protocol, TypeVar
+from typing import Callable, Protocol, TypeVar
 
 from dotenv import load_dotenv
 from langchain_community.llms import HuggingFaceTextGenInference
 from langchain_core.runnables import Runnable
-from langchain_openai import ChatOpenAI
+from langchain_openai import AzureChatOpenAI
 
-from janus.llm.model_callbacks import COST_PER_1K_TOKENS, openai_model_reroutes
+from janus.llm.model_callbacks import COST_PER_1K_TOKENS, azure_model_reroutes
 from janus.prompts.prompt import (
     ChatGptPromptEngine,
     ClaudePromptEngine,
@@ -46,7 +45,7 @@ except ImportError:
 
 ModelType = TypeVar(
     "ModelType",
-    ChatOpenAI,
+    AzureChatOpenAI,
     HuggingFaceTextGenInference,
     Bedrock,
     BedrockChat,
@@ -72,7 +71,6 @@ class JanusModel(Runnable, JanusModelProtocol):
 
 load_dotenv()
 
-
 openai_models = [
     "gpt-4o",
     "gpt-4o-mini",
@@ -82,11 +80,17 @@ openai_models = [
     "gpt-3.5-turbo",
     "gpt-3.5-turbo-16k",
 ]
+azure_models = [
+    "gpt-4o",
+    "gpt-4o-mini",
+    "gpt-3.5-turbo-16k",
+]
 claude_models = [
     "bedrock-claude-v2",
     "bedrock-claude-instant-v1",
     "bedrock-claude-haiku",
     "bedrock-claude-sonnet",
+    "bedrock-claude-sonnet-3.5",
 ]
 llama2_models = [
     "bedrock-llama2-70b",
@@ -120,18 +124,21 @@ bedrock_models = [
     *cohere_models,
     *mistral_models,
 ]
-all_models = [*openai_models, *bedrock_models]
+all_models = [*azure_models, *bedrock_models]
 
 MODEL_TYPE_CONSTRUCTORS: dict[str, ModelType] = {
-    "OpenAI": ChatOpenAI,
+    # "OpenAI": ChatOpenAI,
     "HuggingFace": HuggingFaceTextGenInference,
+    "Azure": AzureChatOpenAI,
     "Bedrock": Bedrock,
     "BedrockChat": BedrockChat,
     "HuggingFaceLocal": HuggingFacePipeline,
 }
 
-MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
-    **{m: ChatGptPromptEngine for m in openai_models},
+
+MODEL_PROMPT_ENGINES: dict[str, Callable[..., PromptEngine]] = {
+    # **{m: ChatGptPromptEngine for m in openai_models},
+    **{m: ChatGptPromptEngine for m in azure_models},
     **{m: ClaudePromptEngine for m in claude_models},
     **{m: Llama2PromptEngine for m in llama2_models},
     **{m: Llama3PromptEngine for m in llama3_models},
@@ -141,11 +148,13 @@ MODEL_PROMPT_ENGINES: dict[str, type[PromptEngine]] = {
 }
 
 MODEL_ID_TO_LONG_ID = {
-    **{m: mr for m, mr in openai_model_reroutes.items()},
+    # **{m: mr for m, mr in openai_model_reroutes.items()},
+    **{m: mr for m, mr in azure_model_reroutes.items()},
     "bedrock-claude-v2": "anthropic.claude-v2",
     "bedrock-claude-instant-v1": "anthropic.claude-instant-v1",
     "bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
     "bedrock-claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
+    "bedrock-claude-sonnet-3.5": "anthropic.claude-3-5-sonnet-20240620-v1:0",
     "bedrock-llama2-70b": "meta.llama2-70b-v1",
     "bedrock-llama2-70b-chat": "meta.llama2-70b-chat-v1",
     "bedrock-llama2-13b": "meta.llama2-13b-chat-v1",
@@ -171,8 +180,9 @@ DEFAULT_MODELS = list(MODEL_DEFAULT_ARGUMENTS.keys())
 
 MODEL_CONFIG_DIR = Path.home().expanduser() / ".janus" / "llm"
 
-MODEL_TYPES: dict[str, str] = {
-    **{m: "OpenAI" for m in openai_models},
+MODEL_TYPES: dict[str, PromptEngine] = {
+    # **{m: "OpenAI" for m in openai_models},
+    **{m: "Azure" for m in azure_models},
     **{m: "BedrockChat" for m in bedrock_models},
 }
 
@@ -182,13 +192,17 @@ TOKEN_LIMITS: dict[str, int] = {
     "gpt-4-1106-preview": 128_000,
     "gpt-4-0125-preview": 128_000,
     "gpt-4o-2024-05-13": 128_000,
+    "gpt-4o-2024-08-06": 128_000,
+    "gpt-4o-mini": 128_000,
     "gpt-3.5-turbo-0125": 16_384,
+    "gpt35-turbo-16k": 16_384,
     "text-embedding-ada-002": 8191,
     "gpt4all": 16_384,
     "anthropic.claude-v2": 100_000,
     "anthropic.claude-instant-v1": 100_000,
     "anthropic.claude-3-haiku-20240307-v1:0": 248_000,
     "anthropic.claude-3-sonnet-20240229-v1:0": 248_000,
+    "anthropic.claude-3-5-sonnet-20240620-v1:0": 200_000,
     "meta.llama2-70b-v1": 4096,
     "meta.llama2-70b-chat-v1": 4096,
     "meta.llama2-13b-chat-v1": 4096,
@@ -270,11 +284,21 @@ def load_model(model_id) -> JanusModel:
             openai_api_key=str(os.getenv("OPENAI_API_KEY")),
             openai_organization=str(os.getenv("OPENAI_ORG_ID")),
         )
-        log.warning("Do NOT use this model in sensitive environments!")
-        log.warning("If you would like to cancel, please press Ctrl+C.")
-        log.warning("Waiting 10 seconds...")
+        # log.warning("Do NOT use this model in sensitive environments!")
+        # log.warning("If you would like to cancel, please press Ctrl+C.")
+        # log.warning("Waiting 10 seconds...")
         # Give enough time for the user to read the warnings and cancel
-        time.sleep(10)
+        # time.sleep(10)
+        raise DeprecationWarning("OpenAI models are no longer supported.")
+
+    elif model_type_name == "Azure":
+        model_args.update(
+            {
+                "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
+                "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
+                "api_version": os.getenv("OPENAI_API_VERSION", "2024-02-01"),
+            }
+        )
 
     model_type = MODEL_TYPE_CONSTRUCTORS[model_type_name]
     prompt_engine = MODEL_PROMPT_ENGINES[model_id]
janus/parsers/partition_parser.py ADDED
@@ -0,0 +1,136 @@
+import json
+import random
+import uuid
+
+from langchain.output_parsers import PydanticOutputParser
+from langchain_core.exceptions import OutputParserException
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from janus.language.block import CodeBlock
+from janus.parsers.parser import JanusParser
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+RNG = random.Random()
+
+
+class PartitionObject(BaseModel):
+    reasoning: str = Field(
+        description="An explanation for why the code should be split at this point"
+    )
+    location: str = Field(
+        description="The 8-character line label which should start a new chunk"
+    )
+
+
+class PartitionList(BaseModel):
+    __root__: list[PartitionObject] = Field(
+        description=(
+            "A list of appropriate split points, each with a `reasoning` field "
+            "that explains a justification for splitting the code at that point, "
+            "and a `location` field which is simply the 8-character line ID. "
+            "The `reasoning` field should always be included first."
+        )
+    )
+
+
+class PartitionParser(JanusParser, PydanticOutputParser):
+    token_limit: int
+    model: BaseLanguageModel
+    lines: list[str] = []
+    line_id_to_index: dict[str, int] = {}
+
+    def __init__(self, token_limit: int, model: BaseLanguageModel):
+        PydanticOutputParser.__init__(
+            self,
+            pydantic_object=PartitionList,
+            model=model,
+            token_limit=token_limit,
+        )
+
+    def parse_input(self, block: CodeBlock) -> str:
+        code = str(block.text)
+        RNG.seed(code)
+
+        self.lines = code.split("\n")
+
+        # Generate a unique ID for each line (ensure they are unique)
+        line_ids = set()
+        while len(line_ids) < len(self.lines):
+            line_ids.add(str(uuid.UUID(int=RNG.getrandbits(128), version=4))[:8])
+
+        # Prepend each line with the corresponding ID, save the mapping
+        self.line_id_to_index = {lid: i for i, lid in enumerate(line_ids)}
+        processed = "\n".join(
+            f"{line_id}\t{self.lines[i]}" for line_id, i in self.line_id_to_index.items()
+        )
+        return processed
+
+    def parse(self, text: str | BaseMessage) -> str:
+        if isinstance(text, BaseMessage):
+            text = str(text.content)
+
+        try:
+            out: PartitionList = super().parse(text)
+        except (OutputParserException, json.JSONDecodeError):
+            log.debug(f"Invalid JSON object. Output:\n{text}")
+            raise
+
+        # Locate any invalid line IDs, raise exception if any found
+        invalid_splits = [
+            partition.location
+            for partition in out.__root__
+            if partition.location not in self.line_id_to_index
+        ]
+        if invalid_splits:
+            err_msg = (
+                f"{len(invalid_splits)} line ID(s) not found in input: "
+                + ", ".join(invalid_splits)
+            )
+            log.warning(err_msg)
+            raise OutputParserException(err_msg)
+
+        # Map line IDs to indices (so they can be sorted and lines indexed)
+        index_to_line_id = {0: "START", None: "END"}
+        split_points = {0}
+        for partition in out.__root__:
+            index = self.line_id_to_index[partition.location]
+            index_to_line_id[index] = partition.location
+            split_points.add(index)
+
+        # Get partition start/ends, chunks, chunk lengths
+        split_points = sorted(split_points) + [None]
+        partition_indices = list(zip(split_points, split_points[1:]))
+        partition_points = [
+            (index_to_line_id[i0], index_to_line_id[i1]) for i0, i1 in partition_indices
+        ]
+        chunks = ["\n".join(self.lines[i0:i1]) for i0, i1 in partition_indices]
+        chunk_tokens = list(map(self.model.get_num_tokens, chunks))
+
+        # Collect any chunks that exceed token limit
+        oversized_indices: list[int] = [
+            i for i, n in enumerate(chunk_tokens) if n > self.token_limit
+        ]
+        if oversized_indices:
+            data = list(zip(partition_points, chunks, chunk_tokens))
+            data = [data[i] for i in oversized_indices]
+
+            problem_points = "\n".join(
+                [
+                    f"{i0} to {i1} ({t / self.token_limit:.1f}x maximum length)"
+                    for (i0, i1), _, t in data
+                ]
+            )
+            log.warning(f"Found {len(data)} oversized chunks:\n{problem_points}")
+            log.debug(
+                "Oversized chunks:\n"
+                + "\n#############\n".join(chunk for _, chunk, _ in data)
+            )
+            raise OutputParserException(
+                f"The following segments are too long and must be "
+                f"further subdivided:\n{problem_points}"
+            )
+
+        return "\n<JANUS_PARTITION>\n".join(chunks)
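For context, this parser tags every input line with a deterministic 8-character label and asks the model to name the labels where a new chunk should start. A standalone sketch of just the labeling step (it mirrors `parse_input` above; the sample source code is invented):

    import random
    import uuid

    code = "MOVE A TO B\nADD 1 TO C\nSTOP RUN"
    rng = random.Random()
    rng.seed(code)  # deterministic: the same source always yields the same labels

    lines = code.split("\n")
    ids = set()
    while len(ids) < len(lines):
        ids.add(str(uuid.UUID(int=rng.getrandbits(128), version=4))[:8])

    for lid, line in zip(ids, lines):
        print(f"{lid}\t{line}")  # e.g. "3f2a9c1d	MOVE A TO B"

Because the RNG is seeded with the code itself, repeated runs over the same file produce identical labels, which makes the model's chosen split points reproducible.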
janus/refiners/refiner.py CHANGED
@@ -1,3 +1,4 @@
+import re
 from typing import Any
 
 from langchain.output_parsers import RetryWithErrorOutputParser
@@ -27,7 +28,7 @@ class JanusRefiner(JanusParser):
 
 class FixParserExceptions(JanusRefiner, RetryWithErrorOutputParser):
     def __init__(self, llm: JanusModel, parser: JanusParser, max_retries: int):
-        retry_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+        retry_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
            source_language="text",
            prompt_template="refinement/fix_exceptions",
        ).prompt
@@ -46,6 +47,7 @@ class ReflectionRefiner(JanusRefiner):
    max_retries: int
    reflection_chain: RunnableSerializable
    revision_chain: RunnableSerializable
+    reflection_prompt_name: str
 
    def __init__(
        self,
@@ -54,11 +56,11 @@ class ReflectionRefiner(JanusRefiner):
        max_retries: int,
        prompt_template_name: str = "refinement/reflection",
    ):
-        reflection_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+        reflection_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
            source_language="text",
            prompt_template=prompt_template_name,
        ).prompt
-        revision_prompt = MODEL_PROMPT_ENGINES[llm.model_id](
+        revision_prompt = MODEL_PROMPT_ENGINES[llm.short_model_id](
            source_language="text",
            prompt_template="refinement/revision",
        ).prompt
@@ -66,6 +68,7 @@ class ReflectionRefiner(JanusRefiner):
        reflection_chain = reflection_prompt | llm | StrOutputParser()
        revision_chain = revision_prompt | llm | StrOutputParser()
        super().__init__(
+            reflection_prompt_name=prompt_template_name,
            reflection_chain=reflection_chain,
            revision_chain=revision_chain,
            parser=parser,
@@ -75,6 +78,7 @@ class ReflectionRefiner(JanusRefiner):
    def parse_completion(
        self, completion: str, prompt_value: PromptValue, **kwargs
    ) -> Any:
+        log.info(f"Reflection Prompt: {self.reflection_prompt_name}")
        for retry_number in range(self.max_retries):
            reflection = self.reflection_chain.invoke(
                dict(
@@ -82,7 +86,7 @@ class ReflectionRefiner(JanusRefiner):
                    completion=completion,
                )
            )
-            if "LGTM" in reflection:
+            if re.search(r"\bLGTM\b", reflection) is not None:
                return self.parser.parse(completion)
            if not retry_number:
                log.info(f"Completion:\n{completion}")
@@ -105,11 +109,3 @@ class HallucinationRefiner(ReflectionRefiner):
            prompt_template_name="refinement/hallucination",
            **kwargs,
        )
-
-
-REFINERS = dict(
-    none=JanusRefiner,
-    parser=FixParserExceptions,
-    reflection=ReflectionRefiner,
-    hallucination=HallucinationRefiner,
-)
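The switch to a word-boundary regex makes the approval check stricter than a plain substring test. A quick standalone illustration (the reflection strings are invented):

    import re

    reflections = [
        "LGTM",                       # approval on its own
        "The code LGTM overall.",     # approval embedded in a sentence
        "ALGTMX is not a real word",  # substring hit that should NOT count
    ]
    for r in reflections:
        substring = "LGTM" in r
        word = re.search(r"\bLGTM\b", r) is not None
        print(f"{r!r}: substring={substring}, word-boundary={word}")
    # Only the third case differs: substring=True but word-boundary=False.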
janus/refiners/uml.py ADDED
@@ -0,0 +1,33 @@
+from janus.llm.models_info import JanusModel
+from janus.parsers.parser import JanusParser
+from janus.refiners.refiner import ReflectionRefiner
+
+
+class ALCFixUMLVariablesRefiner(ReflectionRefiner):
+    def __init__(
+        self,
+        llm: JanusModel,
+        parser: JanusParser,
+        max_retries: int,
+    ):
+        super().__init__(
+            llm=llm,
+            parser=parser,
+            max_retries=max_retries,
+            prompt_template_name="refinement/uml/alc_fix_variables",
+        )
+
+
+class FixUMLConnectionsRefiner(ReflectionRefiner):
+    def __init__(
+        self,
+        llm: JanusModel,
+        parser: JanusParser,
+        max_retries: int,
+    ):
+        super().__init__(
+            llm=llm,
+            parser=parser,
+            max_retries=max_retries,
+            prompt_template_name="refinement/uml/fix_connections",
+        )
janus/retrievers/retriever.py CHANGED
@@ -1,7 +1,16 @@
+from typing import List
+
+from langchain_core.documents import Document
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.runnables import Runnable, RunnableConfig
 
 from janus.language.block import CodeBlock
+from janus.llm.models_info import MODEL_PROMPT_ENGINES, JanusModel
+from janus.utils.logger import create_logger
+from janus.utils.pdf_docs_reader import PDFDocsReader
+
+log = create_logger(__name__)
 
 
 class JanusRetriever(Runnable):
@@ -40,3 +49,54 @@ class TextSearchRetriever(JanusRetriever):
         docs = self.retriever.invoke(code_block.text)
         context = "\n\n".join(doc.page_content for doc in docs)
         return f"You may use the following additional context: {context}"
+
+
+class LanguageDocsRetriever(JanusRetriever):
+    def __init__(
+        self,
+        llm: JanusModel,
+        language_name: str,
+        prompt_template_name: str = "retrieval/language_docs",
+    ):
+        super().__init__()
+        self.llm: JanusModel = llm
+        self.language: str = language_name
+
+        self.PDF_reader = PDFDocsReader(
+            language=self.language,
+        )
+
+        language_docs_prompt = MODEL_PROMPT_ENGINES[self.llm.short_model_id](
+            source_language=self.language,
+            prompt_template=prompt_template_name,
+        ).prompt
+
+        parser: StrOutputParser = StrOutputParser()
+        self.chain = language_docs_prompt | self.llm | parser
+
+    def get_context(self, code_block: CodeBlock) -> str:
+        functionality_to_reference: str = self.chain.invoke(
+            dict({"SOURCE_CODE": code_block.text, "SOURCE_LANGUAGE": self.language})
+        )
+        if functionality_to_reference == "NODOCS":
+            log.debug("No Opcodes requested from language docs retriever.")
+            return ""
+        else:
+            functionality_to_reference: List = functionality_to_reference.split(", ")
+            log.debug(
+                f"List of opcodes requested by language docs retriever "
+                f"to search the {self.language} "
+                f"docs for: {functionality_to_reference}"
+            )
+
+            docs: List[Document] = self.PDF_reader.search_language_reference(
+                functionality_to_reference
+            )
+            context = "\n\n".join(doc.page_content for doc in docs)
+            if context:
+                return (
+                    f"You may reference the following excerpts from the {self.language} "
+                    f"language documentation: {context}"
+                )
+            else:
+                return ""
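A hedged wiring sketch for the new retriever (assumes Azure credentials are configured for `load_model`, and that a reference manual exists at retrieval_docs/<language>.pdf; "ALC" is an example language name and `code_block` stands for any janus CodeBlock):

    from janus.llm.models_info import load_model
    from janus.retrievers.retriever import LanguageDocsRetriever

    llm = load_model("gpt-4o")  # assumes the Azure env variables are set
    # Construction chunks (or loads the cached chunks of) retrieval_docs/ALC.pdf.
    retriever = LanguageDocsRetriever(llm=llm, language_name="ALC")
    # context = retriever.get_context(code_block)  # code_block: a janus CodeBlock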
janus/utils/pdf_docs_reader.py ADDED
@@ -0,0 +1,134 @@
+import os
+import time
+from pathlib import Path
+from typing import List, Optional
+
+import joblib
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_core.documents import Document
+from langchain_unstructured import UnstructuredLoader
+from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+
+from janus.utils.logger import create_logger
+
+log = create_logger(__name__)
+
+
+class PDFDocsReader:
+    def __init__(
+        self,
+        language: str,
+        chunk_size: int = 1000,
+        chunk_overlap: int = 100,
+        start_page: Optional[int] = None,
+        end_page: Optional[int] = None,
+        vectorizer: CountVectorizer = TfidfVectorizer(),
+    ):
+        self.retrieval_docs_dir: Path = Path(
+            os.getenv("RETRIEVAL_DOCS_DIR", "retrieval_docs")
+        )
+        self.language = language
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.start_page = start_page
+        self.end_page = end_page
+        self.vectorizer = vectorizer
+        self.documents = self.load_and_chunk_pdf()
+        self.doc_vectors = self.vectorize_documents()
+
+    def load_and_chunk_pdf(self) -> List[str]:
+        pdf_path = self.retrieval_docs_dir / f"{self.language}.pdf"
+        pickled_documents_path = (
+            self.retrieval_docs_dir / f"{self.language}_documents.pkl"
+        )
+
+        if pickled_documents_path.exists():
+            log.debug(
+                f"Loading pre-chunked PDF from {pickled_documents_path}. "
+                f"If you want to regenerate retrieval docs for {self.language}, "
+                f"delete the file at {pickled_documents_path}, "
+                f"then add a new {self.language}.pdf."
+            )
+            documents = joblib.load(pickled_documents_path)
+        else:
+            if not pdf_path.exists():
+                raise FileNotFoundError(
+                    f"Language docs retrieval is enabled, but no PDF for language "
+                    f"'{self.language}' was found. Move a "
+                    f"{self.language} reference manual to "
+                    f"{pdf_path.absolute()} "
+                    f"(the path to the directory of PDF docs can be "
+                    f"set with the env variable 'RETRIEVAL_DOCS_DIR')."
+                )
+            log.info(
+                f"Chunking reference PDF for {self.language} using unstructured - "
+                f"if your PDF has many pages, this could take a while..."
+            )
+            start_time = time.time()
+            loader = UnstructuredLoader(
+                pdf_path,
+                chunking_strategy="basic",
+                max_characters=1000000,
+                include_orig_elements=False,
+                start_page=self.start_page,
+                end_page=self.end_page,
+            )
+            docs = loader.load()
+            text = "\n\n".join([doc.page_content for doc in docs])
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+            )
+            documents = text_splitter.split_text(text)
+            log.info(f"Document store created for language: {self.language}")
+            end_time = time.time()
+            log.info(
+                f"Processing time for {self.language} PDF: "
+                f"{end_time - start_time} seconds"
+            )
+
+            joblib.dump(documents, pickled_documents_path)
+            log.debug(f"Documents saved to {pickled_documents_path}")
+
+        return documents
+
+    def vectorize_documents(self) -> (TfidfVectorizer, any):
+        doc_vectors = self.vectorizer.fit_transform(self.documents)
+        return doc_vectors
+
+    def search_language_reference(
+        self,
+        query: List[str],
+        top_k: int = 1,
+        min_similarity: float = 0.1,
+    ) -> List[Document]:
+        """Searches through the vectorized PDF for the query using
+        tf-idf and returns a list of langchain Documents."""
+
+        docs: List[Document] = []
+
+        for item in query:
+            # Transform the query using the TF-IDF vectorizer
+            query_vector = self.vectorizer.transform([item])
+
+            # Calculate cosine similarities between the query and document vectors
+            similarities = cosine_similarity(query_vector, self.doc_vectors).flatten()
+
+            # Get the indices of documents with similarity above the threshold
+            valid_indices = [
+                i for i, sim in enumerate(similarities) if sim >= min_similarity
+            ]
+
+            # Sort the valid indices by similarity score in descending order
+            sorted_indices = sorted(
+                valid_indices, key=lambda i: similarities[i], reverse=True
+            )
+
+            # Limit to top-k results
+            top_indices = sorted_indices[:top_k]
+
+            # Retrieve the top-k most relevant documents
+            docs += [Document(page_content=self.documents[i]) for i in top_indices]
+            log.debug(f"Language documentation search result: {docs}")
+
+        return docs
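The search method above is plain TF-IDF plus cosine similarity. A self-contained sketch of the same pattern, using only scikit-learn and an invented mini-corpus standing in for the chunked PDF pages:

    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    documents = [
        "MOVE copies data from one storage area to another.",
        "ADD sums two numeric operands and stores the result.",
        "PERFORM transfers control to a procedure.",
    ]
    vectorizer = TfidfVectorizer()
    doc_vectors = vectorizer.fit_transform(documents)

    # Score one opcode query against every chunk.
    query_vector = vectorizer.transform(["ADD"])
    similarities = cosine_similarity(query_vector, doc_vectors).flatten()
    best = similarities.argmax()
    if similarities[best] >= 0.1:  # the same min_similarity threshold as above
        print(documents[best])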
{janus_llm-4.0.0.dist-info → janus_llm-4.2.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: janus-llm
-Version: 4.0.0
+Version: 4.2.0
 Summary: A transcoding library using LLMs.
 Home-page: https://github.com/janus-llm/janus-llm
 License: Apache 2.0
@@ -23,20 +23,28 @@ Requires-Dist: langchain-anthropic (>=0.1.15,<0.2.0)
 Requires-Dist: langchain-community (>=0.2.0,<0.3.0)
 Requires-Dist: langchain-core (>=0.2.0,<0.3.0)
 Requires-Dist: langchain-openai (>=0.1.8,<0.2.0)
+Requires-Dist: langchain-unstructured (>=0.1.2,<0.2.0)
 Requires-Dist: nltk (>=3.8.1,<4.0.0)
 Requires-Dist: numpy (>=1.24.3,<2.0.0)
 Requires-Dist: openai (>=1.14.0,<2.0.0)
+Requires-Dist: pi-heif (>=0.20.0,<0.21.0)
 Requires-Dist: py-readability-metrics (>=1.4.5,<2.0.0)
 Requires-Dist: py-rouge (>=1.1,<2.0)
+Requires-Dist: pytesseract (>=0.3.13,<0.4.0)
 Requires-Dist: python-dotenv (>=1.0.0,<2.0.0)
 Requires-Dist: rich (>=13.7.1,<14.0.0)
 Requires-Dist: sacrebleu (>=2.4.1,<3.0.0)
+Requires-Dist: scikit-learn (>=1.5.2,<2.0.0)
 Requires-Dist: sentence-transformers (>=2.6.1,<3.0.0) ; extra == "hf-local" or extra == "all"
+Requires-Dist: tesseract (>=0.1.3,<0.2.0)
 Requires-Dist: text-generation (>=0.6.0,<0.7.0)
 Requires-Dist: tiktoken (>=0.7.0,<0.8.0)
 Requires-Dist: transformers (>=4.31.0,<5.0.0)
 Requires-Dist: tree-sitter (>=0.21.0,<0.22.0)
 Requires-Dist: typer (>=0.9.0,<0.10.0)
+Requires-Dist: unstructured (>=0.15.9,<0.16.0)
+Requires-Dist: unstructured-inference (>=0.7.36,<0.8.0)
+Requires-Dist: unstructured-pytesseract (>=0.3.13,<0.4.0)
 Project-URL: Documentation, https://janus-llm.github.io/janus-llm
 Project-URL: Repository, https://github.com/janus-llm/janus-llm
 Description-Content-Type: text/markdown