asag-rag 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asag_rag-0.1.0/PKG-INFO +25 -0
- asag_rag-0.1.0/README.md +1 -0
- asag_rag-0.1.0/pyproject.toml +25 -0
- asag_rag-0.1.0/setup.cfg +4 -0
- asag_rag-0.1.0/src/asag_rag.egg-info/PKG-INFO +25 -0
- asag_rag-0.1.0/src/asag_rag.egg-info/SOURCES.txt +33 -0
- asag_rag-0.1.0/src/asag_rag.egg-info/dependency_links.txt +1 -0
- asag_rag-0.1.0/src/asag_rag.egg-info/requires.txt +17 -0
- asag_rag-0.1.0/src/asag_rag.egg-info/top_level.txt +1 -0
- asag_rag-0.1.0/src/rag/__init__.py +6 -0
- asag_rag-0.1.0/src/rag/config.py +15 -0
- asag_rag-0.1.0/src/rag/embed.py +33 -0
- asag_rag-0.1.0/src/rag/generator.py +29 -0
- asag_rag-0.1.0/src/rag/ingest.py +58 -0
- asag_rag-0.1.0/src/rag/ingestions/__init__.py +15 -0
- asag_rag-0.1.0/src/rag/ingestions/base.py +7 -0
- asag_rag-0.1.0/src/rag/ingestions/csv.py +23 -0
- asag_rag-0.1.0/src/rag/ingestions/excel.py +33 -0
- asag_rag-0.1.0/src/rag/ingestions/parquet.py +26 -0
- asag_rag-0.1.0/src/rag/ingestions/rahutomo.py +43 -0
- asag_rag-0.1.0/src/rag/llms/__init__.py +9 -0
- asag_rag-0.1.0/src/rag/llms/base.py +17 -0
- asag_rag-0.1.0/src/rag/llms/gemini.py +32 -0
- asag_rag-0.1.0/src/rag/llms/huggingface.py +21 -0
- asag_rag-0.1.0/src/rag/llms/openai.py +24 -0
- asag_rag-0.1.0/src/rag/pipeline.py +81 -0
- asag_rag-0.1.0/src/rag/prompt.py +61 -0
- asag_rag-0.1.0/src/rag/retriever.py +19 -0
- asag_rag-0.1.0/src/rag/retrievers/__init__.py +9 -0
- asag_rag-0.1.0/src/rag/retrievers/external.py +32 -0
- asag_rag-0.1.0/src/rag/retrievers/hybrid.py +16 -0
- asag_rag-0.1.0/src/rag/retrievers/local.py +53 -0
- asag_rag-0.1.0/src/rag/retrievers/wrappers/__init__.py +1 -0
- asag_rag-0.1.0/src/rag/retrievers/wrappers/top_k.py +24 -0
- asag_rag-0.1.0/src/rag/splitter.py +66 -0
asag_rag-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: asag-rag
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Requires-Python: >=3.12
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: build>=1.5.0
|
|
8
|
+
Requires-Dist: faiss-cpu>=1.14.2
|
|
9
|
+
Requires-Dist: google-genai>=2.8.0
|
|
10
|
+
Requires-Dist: langchain>=1.3.4
|
|
11
|
+
Requires-Dist: langchain-community>=0.4.2
|
|
12
|
+
Requires-Dist: langchain-huggingface>=1.2.2
|
|
13
|
+
Requires-Dist: langchain-tavily>=0.2.18
|
|
14
|
+
Requires-Dist: openai>=2.41.0
|
|
15
|
+
Requires-Dist: openpyxl>=3.1.5
|
|
16
|
+
Requires-Dist: pandas>=3.0.3
|
|
17
|
+
Requires-Dist: pyarrow>=24.0.0
|
|
18
|
+
Requires-Dist: python-dotenv>=1.2.2
|
|
19
|
+
Requires-Dist: rank-bm25>=0.2.2
|
|
20
|
+
Requires-Dist: scikit-learn>=1.9.0
|
|
21
|
+
Requires-Dist: tavily-python>=0.7.25
|
|
22
|
+
Requires-Dist: transformers==4.57.6
|
|
23
|
+
Requires-Dist: twine>=6.2.0
|
|
24
|
+
|
|
25
|
+
# asag-rag
|
asag_rag-0.1.0/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# asag-rag
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "asag-rag"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Add your description here"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.12"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"build>=1.5.0",
|
|
9
|
+
"faiss-cpu>=1.14.2",
|
|
10
|
+
"google-genai>=2.8.0",
|
|
11
|
+
"langchain>=1.3.4",
|
|
12
|
+
"langchain-community>=0.4.2",
|
|
13
|
+
"langchain-huggingface>=1.2.2",
|
|
14
|
+
"langchain-tavily>=0.2.18",
|
|
15
|
+
"openai>=2.41.0",
|
|
16
|
+
"openpyxl>=3.1.5",
|
|
17
|
+
"pandas>=3.0.3",
|
|
18
|
+
"pyarrow>=24.0.0",
|
|
19
|
+
"python-dotenv>=1.2.2",
|
|
20
|
+
"rank-bm25>=0.2.2",
|
|
21
|
+
"scikit-learn>=1.9.0",
|
|
22
|
+
"tavily-python>=0.7.25",
|
|
23
|
+
"transformers==4.57.6",
|
|
24
|
+
"twine>=6.2.0",
|
|
25
|
+
]
|
asag_rag-0.1.0/setup.cfg
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: asag-rag
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Requires-Python: >=3.12
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: build>=1.5.0
|
|
8
|
+
Requires-Dist: faiss-cpu>=1.14.2
|
|
9
|
+
Requires-Dist: google-genai>=2.8.0
|
|
10
|
+
Requires-Dist: langchain>=1.3.4
|
|
11
|
+
Requires-Dist: langchain-community>=0.4.2
|
|
12
|
+
Requires-Dist: langchain-huggingface>=1.2.2
|
|
13
|
+
Requires-Dist: langchain-tavily>=0.2.18
|
|
14
|
+
Requires-Dist: openai>=2.41.0
|
|
15
|
+
Requires-Dist: openpyxl>=3.1.5
|
|
16
|
+
Requires-Dist: pandas>=3.0.3
|
|
17
|
+
Requires-Dist: pyarrow>=24.0.0
|
|
18
|
+
Requires-Dist: python-dotenv>=1.2.2
|
|
19
|
+
Requires-Dist: rank-bm25>=0.2.2
|
|
20
|
+
Requires-Dist: scikit-learn>=1.9.0
|
|
21
|
+
Requires-Dist: tavily-python>=0.7.25
|
|
22
|
+
Requires-Dist: transformers==4.57.6
|
|
23
|
+
Requires-Dist: twine>=6.2.0
|
|
24
|
+
|
|
25
|
+
# asag-rag
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
src/asag_rag.egg-info/PKG-INFO
|
|
4
|
+
src/asag_rag.egg-info/SOURCES.txt
|
|
5
|
+
src/asag_rag.egg-info/dependency_links.txt
|
|
6
|
+
src/asag_rag.egg-info/requires.txt
|
|
7
|
+
src/asag_rag.egg-info/top_level.txt
|
|
8
|
+
src/rag/__init__.py
|
|
9
|
+
src/rag/config.py
|
|
10
|
+
src/rag/embed.py
|
|
11
|
+
src/rag/generator.py
|
|
12
|
+
src/rag/ingest.py
|
|
13
|
+
src/rag/pipeline.py
|
|
14
|
+
src/rag/prompt.py
|
|
15
|
+
src/rag/retriever.py
|
|
16
|
+
src/rag/splitter.py
|
|
17
|
+
src/rag/ingestions/__init__.py
|
|
18
|
+
src/rag/ingestions/base.py
|
|
19
|
+
src/rag/ingestions/csv.py
|
|
20
|
+
src/rag/ingestions/excel.py
|
|
21
|
+
src/rag/ingestions/parquet.py
|
|
22
|
+
src/rag/ingestions/rahutomo.py
|
|
23
|
+
src/rag/llms/__init__.py
|
|
24
|
+
src/rag/llms/base.py
|
|
25
|
+
src/rag/llms/gemini.py
|
|
26
|
+
src/rag/llms/huggingface.py
|
|
27
|
+
src/rag/llms/openai.py
|
|
28
|
+
src/rag/retrievers/__init__.py
|
|
29
|
+
src/rag/retrievers/external.py
|
|
30
|
+
src/rag/retrievers/hybrid.py
|
|
31
|
+
src/rag/retrievers/local.py
|
|
32
|
+
src/rag/retrievers/wrappers/__init__.py
|
|
33
|
+
src/rag/retrievers/wrappers/top_k.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
build>=1.5.0
|
|
2
|
+
faiss-cpu>=1.14.2
|
|
3
|
+
google-genai>=2.8.0
|
|
4
|
+
langchain>=1.3.4
|
|
5
|
+
langchain-community>=0.4.2
|
|
6
|
+
langchain-huggingface>=1.2.2
|
|
7
|
+
langchain-tavily>=0.2.18
|
|
8
|
+
openai>=2.41.0
|
|
9
|
+
openpyxl>=3.1.5
|
|
10
|
+
pandas>=3.0.3
|
|
11
|
+
pyarrow>=24.0.0
|
|
12
|
+
python-dotenv>=1.2.2
|
|
13
|
+
rank-bm25>=0.2.2
|
|
14
|
+
scikit-learn>=1.9.0
|
|
15
|
+
tavily-python>=0.7.25
|
|
16
|
+
transformers==4.57.6
|
|
17
|
+
twine>=6.2.0
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
rag
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
@dataclass
class RAGConfig:
    """Default settings for the embedding model and the grading LLM.

    Chunking options (``chunk_size`` / ``chunk_overlap``) are not part of
    the configuration at the moment.
    """

    # Retrieval settings.
    top_k: int = 5
    embedding_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

    # Generation settings.
    llm_model: str = "gpt-4o-mini"
    temperature: float = 0.1
    top_p: float = 0.01
    max_tokens: int = 1024
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
|
|
2
|
+
from langchain_community.docstore.in_memory import InMemoryDocstore
|
|
3
|
+
from langchain_community.vectorstores import FAISS
|
|
4
|
+
from uuid import uuid4
|
|
5
|
+
from typing import List
|
|
6
|
+
from langchain_core.documents import Document
|
|
7
|
+
import faiss
|
|
8
|
+
|
|
9
|
+
class RAGEmbed:
    """Embed documents with a Hugging Face model and manage a FAISS store.

    ``vector_store`` stays ``None`` until :meth:`build_vector_store` or
    :meth:`load` has been called.
    """

    def __init__(self, model_name: str):
        self.model = HuggingFaceEmbeddings(model_name=model_name)
        self.vector_store: FAISS | None = None

    def _require_store(self) -> FAISS:
        """Return the vector store, failing clearly when none exists yet."""
        if self.vector_store is None:
            raise RuntimeError(
                "Vector store is not initialised; call build_vector_store() or load() first."
            )
        return self.vector_store

    def build_vector_store(self, documents: List[Document]):
        """Create an empty L2 FAISS index sized for this embedding model.

        ``documents`` is only used to probe the embedding dimension; the
        documents themselves are NOT inserted here (use
        :meth:`add_documents`). An empty list is accepted, in which case a
        fixed probe string is embedded instead.
        """
        probe_text = documents[0].page_content if documents else "dimension probe"
        dimension = len(self.model.embed_query(probe_text))
        self.vector_store = FAISS(
            embedding_function=self.model,
            index=faiss.IndexFlatL2(dimension),
            docstore=InMemoryDocstore(),
            index_to_docstore_id={},
        )

    def add_documents(self, documents: List[Document]):
        """Insert ``documents`` into the store under fresh UUID4 ids.

        Raises:
            RuntimeError: if no vector store has been built or loaded.
        """
        store = self._require_store()
        if not documents:
            return  # nothing to embed
        uuids = [str(uuid4()) for _ in documents]
        store.add_documents(documents=documents, ids=uuids)

    def load(self, path: str):
        """Load a previously saved FAISS store from ``path``.

        SECURITY: ``allow_dangerous_deserialization=True`` unpickles the
        docstore. Only load indexes that come from a trusted source.
        """
        self.vector_store = FAISS.load_local(
            path, self.model, allow_dangerous_deserialization=True
        )

    def save(self, path: str):
        """Persist the current store to ``path``.

        Raises:
            RuntimeError: if no vector store has been built or loaded.
        """
        self._require_store().save_local(path)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from typing import Literal, List
|
|
2
|
+
from langchain_core.documents import Document
|
|
3
|
+
from .llms import BACKENDS
|
|
4
|
+
from .llms.base import BaseLLM
|
|
5
|
+
from .prompt import build_prompt, build_prompt_list
|
|
6
|
+
|
|
7
|
+
class RAGGenerator:
    """Build grading prompts and send them to a pluggable LLM backend."""

    def __init__(
        self,
        backend: Literal["huggingface", "openai", "gemini"],
        model_name: str,
        system_prompt: str | None = None,
        **kwargs,
    ):
        """Instantiate the backend registered under ``backend`` in ``BACKENDS``.

        Extra keyword arguments are forwarded to the backend constructor.
        """
        self.system_prompt = system_prompt
        self.llm: BaseLLM = self._load_backend(backend, model_name, **kwargs)

    def _load_backend(self, backend: str, model_name: str, **kwargs) -> BaseLLM:
        """Look up and construct the backend; raise ``ValueError`` if unknown."""
        if backend not in BACKENDS:
            raise ValueError(f"Unknown backend '{backend}'. Choose from: {list(BACKENDS.keys())}")
        return BACKENDS[backend](model_name=model_name, **kwargs)

    def generate(
        self,
        query: str,
        reference: List[str],
        docs: List[Document] | None = None,
        min_score=0,
        max_score=5,
    ) -> str:
        """Grade one answer.

        Args:
            query: ``"<question>[CLS]<answer>"``. Only the first ``[CLS]``
                separates the two parts, so an answer that itself contains
                ``[CLS]`` no longer breaks the unpacking.
            reference: Reference answers shown to the model.
            docs: Retrieved context documents; ``None`` means no context.
            min_score: Lowest score the model may assign.
            max_score: Highest score the model may assign.

        Raises:
            ValueError: if ``query`` contains no ``[CLS]`` separator.
        """
        question, separator, answer = query.partition("[CLS]")
        if not separator:
            raise ValueError("query must have the form '<question>[CLS]<answer>'")
        # A fresh list per call avoids sharing one mutable default object.
        context_docs = [] if docs is None else docs
        prompt = build_prompt_list(
            question, answer, reference, context_docs, self.system_prompt, min_score, max_score
        )
        return self.llm.generate(prompt)

    def generate_with_prompt(self, prompt: str) -> str:
        """Send an already-built prompt to the backend unchanged."""
        return self.llm.generate(prompt)
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import List, Dict, Any
|
|
3
|
+
from langchain_core.documents import Document
|
|
4
|
+
from .ingestions import LOADERS, NAME_LOADERS
|
|
5
|
+
from .ingestions import BaseLoader
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
class RAGIngest:
|
|
9
|
+
def __init__(self):
|
|
10
|
+
self.documents = []
|
|
11
|
+
pass
|
|
12
|
+
|
|
13
|
+
def _resolve_loader(self, source: str, name: str | None, loader_kwargs: Dict[str, Any] = {}) -> BaseLoader:
|
|
14
|
+
ext = Path(source).suffix.lower()
|
|
15
|
+
if ext not in LOADERS and name not in NAME_LOADERS:
|
|
16
|
+
raise ValueError(f"Unsupported file type '{ext}'. Supported: {list(LOADERS.keys())}")
|
|
17
|
+
if name is not None:
|
|
18
|
+
return NAME_LOADERS[name](**loader_kwargs)
|
|
19
|
+
return LOADERS[ext](**loader_kwargs)
|
|
20
|
+
|
|
21
|
+
def load(
|
|
22
|
+
self,
|
|
23
|
+
source: str,
|
|
24
|
+
name: str = None,
|
|
25
|
+
loader_kwargs: Dict[str, Any] = {},
|
|
26
|
+
):
|
|
27
|
+
loader = self._resolve_loader(source, name, loader_kwargs)
|
|
28
|
+
df = loader.load(source)
|
|
29
|
+
df['filename'] = source.split('\\')[-1]
|
|
30
|
+
return df
|
|
31
|
+
|
|
32
|
+
def merge_df(self, dfs: List[pd.DataFrame]):
|
|
33
|
+
merged = pd.concat(dfs, ignore_index=True)
|
|
34
|
+
return merged
|
|
35
|
+
|
|
36
|
+
def join_df(self, df1, df2, on: List[str], how='inner'):
|
|
37
|
+
if len(on) > 1:
|
|
38
|
+
join_df = pd.merge(df1, df2, how=how, left_on=on[0], right_on=on[1])
|
|
39
|
+
else:
|
|
40
|
+
join_df = pd.merge(df1, df2, how=how, on=on[0])
|
|
41
|
+
return join_df
|
|
42
|
+
|
|
43
|
+
def load_directory(self, directory: str, name: str = None, recursive: bool = True, loader_kwargs: Dict[str, Any] = {}) -> List[pd.DataFrame]:
|
|
44
|
+
"""Load all supported files from a directory."""
|
|
45
|
+
root = Path(directory)
|
|
46
|
+
pattern = "**/*" if recursive else "*"
|
|
47
|
+
all_df = []
|
|
48
|
+
|
|
49
|
+
for path in root.glob(pattern):
|
|
50
|
+
if path.is_file() and path.suffix.lower() in LOADERS:
|
|
51
|
+
df = self.load(str(path), name=name, loader_kwargs=loader_kwargs)
|
|
52
|
+
all_df.append(df)
|
|
53
|
+
return all_df
|
|
54
|
+
|
|
55
|
+
def load_processed(self, dataset_path:str):
|
|
56
|
+
"""Load all processed dataset from a path"""
|
|
57
|
+
|
|
58
|
+
return pd.read_csv(dataset_path)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .base import BaseLoader
|
|
2
|
+
from .rahutomo import RahutomoLoader
|
|
3
|
+
from .excel import ExcelLoader
|
|
4
|
+
from .csv import CSVDataLoader
|
|
5
|
+
from .parquet import ParquetLoader
|
|
6
|
+
|
|
7
|
+
# Loader classes selected by file extension (looked up lower-cased).
LOADERS = {
    ".parquet": ParquetLoader,
    ".csv": CSVDataLoader,
    ".xlsx": ExcelLoader,
}

# Loader classes selected explicitly by dataset name.
NAME_LOADERS = {"rahutomo": RahutomoLoader}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from .base import BaseLoader
|
|
4
|
+
|
|
5
|
+
class CSVDataLoader(BaseLoader):
    """Read a delimited text file into a DataFrame with ``pandas.read_csv``."""

    def __init__(self, usecols: List[str] | None = None, sep: str = ","):
        # Both options are passed straight through to pandas on load().
        self.sep = sep
        self.usecols = usecols

    def load(self, source: str) -> pd.DataFrame:
        """Return the contents of ``source`` using the configured options."""
        return pd.read_csv(source, usecols=self.usecols, sep=self.sep)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from typing import List
|
|
3
|
+
from .base import BaseLoader
|
|
4
|
+
|
|
5
|
+
class ExcelLoader(BaseLoader):
    """Read one or more worksheets of an Excel workbook with pandas."""

    def __init__(
        self,
        sheet_name: str | int | None = 0,
        content_col: str | None = None,
        usecols: List[str] | None = None,
        skiprows: int = 0,
    ):
        self.sheet_name = sheet_name
        # Stored for callers; load() itself does not read this attribute.
        self.content_col = content_col
        self.usecols = usecols
        self.skiprows = skiprows

    def load(self, source: str) -> dict[str, pd.DataFrame]:
        """Return the requested sheets as a dict of DataFrames.

        When pandas yields a single DataFrame it is wrapped under the key
        ``"sheet0"``; otherwise the object returned by pandas is passed
        through as-is.
        """
        read_options = {
            "sheet_name": self.sheet_name,
            "usecols": self.usecols,
            "skiprows": self.skiprows,
        }
        loaded = pd.read_excel(source, **read_options)
        return {"sheet0": loaded} if isinstance(loaded, pd.DataFrame) else loaded
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from typing import List
|
|
3
|
+
from .base import BaseLoader
|
|
4
|
+
|
|
5
|
+
class ParquetLoader(BaseLoader):
    """Read a Parquet file into a DataFrame with ``pandas.read_parquet``."""

    def __init__(
        self,
        columns: List[str] | None = None,
        filters=None,
        engine: str = "pyarrow",
    ):
        # All three options are forwarded unchanged to pandas on load().
        self.engine = engine
        self.filters = filters
        self.columns = columns

    def load(self, source: str) -> pd.DataFrame:
        """Return the contents of ``source`` using the configured options."""
        read_options = {
            "columns": self.columns,
            "filters": self.filters,
            "engine": self.engine,
        }
        return pd.read_parquet(source, **read_options)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from .excel import ExcelLoader
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
class RahutomoLoader(ExcelLoader):
    """
    Domain-specific loader for multi-sheet exam Excel files.

    Structure expected:
    - Sheet 'Soal'  : reference with question text & answer key per row
    - Sheet[1:-1]   : one sheet per question with candidate answer rows
                      (the first and the last sheet are not treated as
                      question sheets)
    - First column  : dropped (row numbering)

    The i-th question sheet is paired with the i-th row of 'Soal'.
    """

    # Columns of the frame returned by load(), in output order.
    _OUTPUT_COLUMNS = ["SheetName", "Soal", "KunciJawaban", "Jawaban", "Manual1", "Manual2", "Manual3"]

    def __init__(self,
                 skiprows: int = 1,
                 usecols: List[str] | None = None,):
        # sheet_name=None makes pandas return every sheet of the workbook.
        super().__init__(sheet_name=None, skiprows=skiprows, usecols=usecols)

    def load(self, source: str) -> pd.DataFrame:
        """Flatten all question sheets into one de-duplicated DataFrame.

        Returns an empty frame with the expected columns when the workbook
        has no question sheets.

        Raises:
            KeyError: if the 'Soal' sheet or an expected column is missing.
        """
        raw_sheets = super().load(source)
        all_sheets = {key: frame.iloc[:, 1:] for key, frame in raw_sheets.items()}

        question_sheets = list(all_sheets)[1:-1]
        ref_df = all_sheets["Soal"].reset_index(drop=True)

        processed = []
        for idx, sheet_name in enumerate(question_sheets):
            data = all_sheets[sheet_name].dropna().copy()
            # str() guards against non-string headers (numbers, NaN) that
            # pandas can produce for unnamed or numeric header cells.
            data.columns = [str(col).strip().replace(" ", "") for col in data.columns]

            if "o" in data.columns:
                print(f"[warn] Sheet '{sheet_name}' has unexpected column 'o'")

            ref_row = ref_df.loc[idx]
            # Some workbooks carry a trailing space in the header.
            soal_col = "Soal" if "Soal" in ref_row.index else "Soal "
            data["Soal"] = ref_row[soal_col]
            data["KunciJawaban"] = ref_row["Kunci Jawaban"]
            data["SheetName"] = sheet_name

            processed.append(data[self._OUTPUT_COLUMNS])

        if not processed:
            # pd.concat raises on an empty list; return an empty frame instead.
            return pd.DataFrame(columns=self._OUTPUT_COLUMNS)
        return pd.concat(processed, ignore_index=True).drop_duplicates()
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
class BaseLLM(ABC):
|
|
5
|
+
@abstractmethod
|
|
6
|
+
def generate(self, prompt: str | List[str]) -> str:
|
|
7
|
+
"""
|
|
8
|
+
Sending prompt to LLM to generate response
|
|
9
|
+
|
|
10
|
+
Input:
|
|
11
|
+
- prompt: str | List[str] = single prompt or multiple prompt
|
|
12
|
+
|
|
13
|
+
Output:
|
|
14
|
+
- str = LLM response
|
|
15
|
+
|
|
16
|
+
"""
|
|
17
|
+
...
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from .base import BaseLLM
|
|
2
|
+
from google.genai import Client
|
|
3
|
+
from google.genai import types
|
|
4
|
+
from typing import List
|
|
5
|
+
from dotenv import load_dotenv
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
load_dotenv()
|
|
9
|
+
|
|
10
|
+
class GeminiLLM(BaseLLM):
    """
    Wrapper Class for Google Gemini

    The API key is read from the ``GOOGLE_API_KEY`` environment variable.
    """

    def __init__(self, model_name: str = "", temperature: float = 0.0, max_tokens: int = 512, top_p: float = 1):
        self.client = Client(api_key=os.environ.get("GOOGLE_API_KEY"))
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p

    @staticmethod
    def _split_prompt(prompt):
        """Split a prompt into ``(system_instruction, user_texts)``.

        Accepts a plain string, a list of strings, or a list of chat message
        dicts with ``role`` / ``content`` keys. Messages whose role is
        ``"system"`` become the system instruction; everything else is sent
        as user content.
        """
        if isinstance(prompt, str):
            return None, [prompt]

        system_texts, user_texts = [], []
        for item in prompt:
            if isinstance(item, dict):
                text = str(item.get("content", ""))
                target = system_texts if item.get("role") == "system" else user_texts
                target.append(text)
            else:
                user_texts.append(str(item))
        return ("\n".join(system_texts) or None), user_texts

    def generate(self, prompt: str | List[str]) -> str:
        """Generate a completion for ``prompt``.

        Chat-style message lists are converted first, because
        ``Part.from_text`` only accepts text. Returns an empty string when
        the response carries no text.
        """
        system_instruction, user_texts = self._split_prompt(prompt)
        output = self.client.models.generate_content(
            model=self.model_name,
            contents=[types.Part.from_text(text=text) for text in user_texts],
            config=types.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=self.temperature,
                top_p=self.top_p,
                max_output_tokens=self.max_tokens
            ),
        )
        return output.text or ""
|
|
32
|
+
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from .base import BaseLLM
|
|
2
|
+
from transformers import pipeline
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
class HuggingFaceLLM(BaseLLM):
    """
    Wrapper Class for Huggingface Models

    Runs a ``transformers`` text-generation pipeline.
    """

    def __init__(self, model_name: str, max_new_tokens: int = 512, temperature: float = 0.0, device: str = "cpu"):
        """Create the pipeline; a temperature above 0 enables sampling."""
        sampling = temperature > 0.0
        generation_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": sampling}
        if sampling:
            # Forward temperature only when sampling: with greedy decoding it
            # has no effect, and a value of 0.0 is not a valid temperature.
            generation_kwargs["temperature"] = temperature
        self.pipe = pipeline(
            "text-generation",
            model=model_name,
            device=device,
            **generation_kwargs,
        )

    def generate(self, prompt: str | List[str]) -> str:
        """Run the pipeline and return the generated text.

        For chat-style prompts (a list of message dicts) the content of the
        last message is returned. For a plain string prompt the pipeline's
        ``generated_text`` string is returned as-is.
        """
        output = self.pipe(prompt)
        first = output[0]
        if isinstance(first, list):
            # Batched input yields one result list per prompt; use the first.
            first = first[0]
        generated = first["generated_text"]
        if isinstance(generated, str):
            return generated
        return generated[-1]['content']
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from .base import BaseLLM
|
|
2
|
+
from openai import OpenAI
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
class OpenAILLM(BaseLLM):
    """
    Wrapper Class for OpenAI LLM APIs

    Uses the Chat Completions endpoint; the client is created with its
    default configuration.
    """

    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.0, max_tokens: int = 512, top_p: float = 1):
        self.client = OpenAI()
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p

    @staticmethod
    def _as_messages(prompt):
        """Normalise ``prompt`` into a list of chat message dicts."""
        if isinstance(prompt, str):
            return [{"role": "user", "content": prompt}]
        return [
            {"role": "user", "content": item} if isinstance(item, str) else item
            for item in prompt
        ]

    def generate(self, prompt: str | List[str]) -> str:
        """Request a chat completion and return the stripped reply text.

        Accepts a plain string, a list of strings, or a ready-made list of
        message dicts. Returns an empty string when the reply has no content.
        """
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=self._as_messages(prompt),
            temperature=self.temperature,
            top_p=self.top_p,
            # Send a single token limit; `max_tokens` is the deprecated
            # spelling of the same setting and must not be sent alongside it.
            max_completion_tokens=self.max_tokens,
        )
        content = response.choices[0].message.content
        return content.strip() if content else ""
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from .embed import RAGEmbed
|
|
2
|
+
from .retriever import RAGRetriever
|
|
3
|
+
from .ingest import RAGIngest
|
|
4
|
+
from .config import RAGConfig
|
|
5
|
+
from langchain_core.documents import Document
|
|
6
|
+
from .generator import RAGGenerator
|
|
7
|
+
import numpy as np
|
|
8
|
+
import time
|
|
9
|
+
|
|
10
|
+
class RAGPipeline:
    """End-to-end grading pipeline: optional retrieval, then LLM scoring.

    NOTE(review): retrieval text is read from a hard-coded ``'input'``
    column (in ``load_example`` and ``run``) while generation uses
    ``answer_column_name`` - confirm the two are meant to differ.
    """

    def __init__(self,
                 cfg: RAGConfig,
                 embedding_path: str,
                 question_column_name: str,
                 reference_column_name: str,
                 answer_column_name: str,
                 ):
        """Load the saved embedding index and set up the generator.

        The generator always uses the ``'huggingface'`` backend with
        ``cfg.llm_model``, ``cfg.temperature`` and ``cfg.max_tokens``.
        """
        self.embedder = RAGEmbed(model_name=cfg.embedding_model)
        self.embedder.load(embedding_path)
        self.ingest = RAGIngest()
        self.local_retriever = RAGRetriever()
        self.question_column_name = question_column_name
        self.reference_column_name = reference_column_name
        self.answer_column_name = answer_column_name
        # Defined up front so retrieve() before load_example() does not fail
        # with an AttributeError.
        self.documents = []

        system_prompt = (
            "\n"
            "Kamu adalah evaluator yang kritis, tegas, dan adil dalam memberikan jawaban dan menyesuaikan dengan fakta dan kriteria yang berlaku. Jika jawaban hanya berisi '-', kosong, atau tidak menjawab pertanyaan, maka berikan skor rendah menyesuaikan apa yang ditulis.\n"
            "\n"
            "/no_think\n"
        )
        self.generator = RAGGenerator('huggingface', cfg.llm_model, system_prompt=system_prompt, temperature=cfg.temperature, max_new_tokens=cfg.max_tokens)

    def load_example(self, dataset_path: str):
        """Build the in-memory example documents from ``<dataset_path>/example.csv``."""
        example = self.ingest.load_processed(f"{dataset_path}/example.csv")
        self.documents = [
            Document(
                page_content=row['input'],
                metadata={
                    'question': row[self.question_column_name],
                    'reference_answer': row[self.reference_column_name],
                },
            )
            for _, row in example.iterrows()
        ]

    def retrieve(self, query):
        """Return local context documents for ``query``.

        Raises:
            ValueError: if the retriever returns no documents.
        """
        res = self.local_retriever.retrieve(
            f'{query}', 'local', {
                "documents": self.documents,
                "vector_store": self.embedder.vector_store
            })
        if len(res) < 1:
            raise ValueError("Context not sufficient")
        return res

    def run(self, data, output_folder, do_retrieve, batch_num, start_batch: int = 0):
        """Score ``data`` in ``batch_num`` batches, writing one CSV per batch.

        Results go to ``<output_folder>/rag/result-<n>.csv`` when
        ``do_retrieve`` is true, else ``<output_folder>/no-rag/result-<n>.csv``
        (``n`` is 1-based). The directory is created if missing.

        Args:
            start_batch: 0-based index of the first batch to process. Pass a
                higher value to resume an interrupted run; earlier batches
                are skipped.

        Raises:
            ValueError: if ``start_batch`` is outside ``0..batch_num``.
        """
        from pathlib import Path  # local import: only needed for output paths

        if not 0 <= start_batch <= batch_num:
            raise ValueError(f"start_batch must be between 0 and {batch_num}, got {start_batch}")

        batch_idxs = np.linspace(0, len(data[self.question_column_name]), batch_num + 1, dtype=int)
        out_dir = Path(output_folder) / ("rag" if do_retrieve else "no-rag")
        out_dir.mkdir(parents=True, exist_ok=True)

        for idx in range(start_batch, len(batch_idxs) - 1):
            start, stop = int(batch_idxs[idx]), int(batch_idxs[idx + 1])
            scores = []
            for i in range(start, stop):
                docs = []
                row = data.iloc[i]
                if do_retrieve:
                    docs = self.retrieve(f"{row[self.question_column_name]}[CLS]{row['input']}")

                response = self.generator.generate(
                    f"{row[self.question_column_name]}[CLS]{row[self.answer_column_name]}",
                    [row[self.reference_column_name]],
                    docs)
                print(response)

                scores.append(response)

            # Copy the slice so adding a column never touches `data`.
            saved = data.iloc[start:stop].copy()
            saved['Prediksi'] = scores
            saved.to_csv(out_dir / f"result-{idx+1}.csv", index=False)

    def run_train(self, dataset_path, output_folder, do_retrieve=True, batch_num: int = 10, start_batch: int = 0):
        """Load the examples and ``train.csv`` from ``dataset_path``, then run scoring."""
        self.load_example(dataset_path)
        train = self.ingest.load_processed(f"{dataset_path}/train.csv")
        self.run(train, output_folder, do_retrieve, batch_num, start_batch)
|
|
80
|
+
|
|
81
|
+
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from langchain_core.documents import Document
|
|
3
|
+
|
|
4
|
+
DEFAULT_SYSTEM_PROMPT = (
|
|
5
|
+
"You are a helpful assistant. Answer the question using only the provided context. "
|
|
6
|
+
"If the answer is not in the context, say you don't know."
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
def build_prompt_list(question: str, answer: str, reference: List[str], docs: List[Document], system_prompt: str | None = None, min_score=0, max_score=5) -> str:
|
|
10
|
+
context = "\n".join(
|
|
11
|
+
f"[{i+1}] {doc.page_content}" for i, doc in enumerate(docs)
|
|
12
|
+
) if len(docs) > 0 else ""
|
|
13
|
+
references = "\n".join(f"[{i+1}] {ref}" for i, ref in enumerate(reference))
|
|
14
|
+
|
|
15
|
+
system = system_prompt or DEFAULT_SYSTEM_PROMPT
|
|
16
|
+
return [{
|
|
17
|
+
"role": "system",
|
|
18
|
+
"content": system
|
|
19
|
+
}, {
|
|
20
|
+
"role": "user",
|
|
21
|
+
"content": (
|
|
22
|
+
f"{f'Context: {context}' if context != "" else ""}"
|
|
23
|
+
f"""
|
|
24
|
+
Given the question and answer data, evaluate the answer and give the score in range {min_score}-{max_score} with the criterias below:
|
|
25
|
+
- Answer cannot be empty or not answering the question. If that is the case, then assign {min_score}.
|
|
26
|
+
- Please tailor the provided answers to the reference as closely as possible
|
|
27
|
+
- If you dont understand about the answer, you can use context and reference to guide you in grading.
|
|
28
|
+
|
|
29
|
+
Question: \"\"\"{question}\"\"\"
|
|
30
|
+
Answer: \"\"\"{answer}\"\"\"
|
|
31
|
+
{"References:\n" + references if len(reference) > 0 else ""}
|
|
32
|
+
|
|
33
|
+
Generate the output with the format below:
|
|
34
|
+
"""
|
|
35
|
+
f"Score: <the answer's score based on reference and context in number format>")
|
|
36
|
+
}]
|
|
37
|
+
|
|
38
|
+
def build_prompt(question: str, answer: str, reference: List[str], docs: List[Document], system_prompt: str | None = None, min_score=0, max_score=5) -> str:
|
|
39
|
+
context = "\n".join(
|
|
40
|
+
f"[{i+1}] {doc.page_content}" for i, doc in enumerate(docs)
|
|
41
|
+
) if len(docs) > 0 else ""
|
|
42
|
+
references = "\n".join(f"[{i+1}] {ref}" for i, ref in enumerate(reference))
|
|
43
|
+
|
|
44
|
+
system = system_prompt or DEFAULT_SYSTEM_PROMPT
|
|
45
|
+
return (
|
|
46
|
+
f"{system}\n\n"
|
|
47
|
+
f"{f'Context: {context}' if context != "" else ""}"
|
|
48
|
+
f"""
|
|
49
|
+
Given the question and answer data, evaluate the answer and give the score in range {min_score}-{max_score} with the criterias below:
|
|
50
|
+
- Answer cannot be empty or not answering the question. If that is the case, then assign {min_score}.
|
|
51
|
+
- Please tailor the provided answers to the reference as closely as possible
|
|
52
|
+
- If you dont understand about the answer, you can use context and reference to guide you in grading.
|
|
53
|
+
|
|
54
|
+
Question: \"\"\"{question}\"\"\"
|
|
55
|
+
Answer: \"\"\"{answer}\"\"\"
|
|
56
|
+
{"References:\n" + references if len(reference) > 0 else ""}
|
|
57
|
+
|
|
58
|
+
Generate the output with the format below:
|
|
59
|
+
"""
|
|
60
|
+
f"Score: <the answer's score based on reference and context in number format>"
|
|
61
|
+
)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import List, Dict, Any
|
|
3
|
+
from src.rag.retrievers import SOURCES
|
|
4
|
+
from langchain_core.retrievers import BaseRetriever
|
|
5
|
+
from langchain_core.documents import Document
|
|
6
|
+
|
|
7
|
+
class RAGRetriever:
    """Dispatch a query to one of the retrievers registered in ``SOURCES``.

    NOTE(review): ``SOURCES`` is imported through ``src.rag.retrievers``;
    the distribution's top-level package is ``rag``, so confirm that import
    resolves once the package is installed.
    """

    def __init__(self):
        pass

    def _resolve_loader(self, source: str, loader_kwargs: Dict[str, Any] | None = None) -> BaseRetriever:
        """Instantiate the retriever registered under ``source``.

        Raises:
            ValueError: if ``source`` is not a key of ``SOURCES``.
        """
        if source not in SOURCES:
            raise ValueError(f"Unsupported source '{source}'. Supported: {list(SOURCES.keys())}")
        # `or {}` gives each call its own dict instead of a shared default.
        return SOURCES[source](**(loader_kwargs or {}))

    def retrieve(self, query: str, source: str, loader_kwargs: Dict[str, Any] | None = None) -> List[Document]:
        """Build the ``source`` retriever with ``loader_kwargs`` and invoke it on ``query``."""
        retriever = self._resolve_loader(source, loader_kwargs)
        return retriever.invoke(query)
|
|
19
|
+
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from langchain_core.retrievers import BaseRetriever
|
|
2
|
+
from typing import Dict, List
|
|
3
|
+
from langchain_community.retrievers import TavilySearchAPIRetriever
|
|
4
|
+
from langchain_core.documents import Document
|
|
5
|
+
import json
|
|
6
|
+
|
|
7
|
+
class ExternalRetriever(BaseRetriever):
    """Web retriever backed by Tavily search, with an in-memory result cache.

    Only the part of the query before ``[CLS]`` is searched, and results are
    cached under that same text.
    """

    # Search text -> documents returned for it.
    caches: Dict[str, List[Document]] = {}
    # Declared as a model field so it can be set on this pydantic model.
    tool: TavilySearchAPIRetriever | None = None

    def __init__(self, k: int = 3, **kwargs):
        """Create the Tavily retriever that returns ``k`` results per search."""
        kwargs.setdefault("tool", TavilySearchAPIRetriever(k=k))
        super().__init__(**kwargs)

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Return cached results for the question, searching on a cache miss."""
        # Key on the text actually searched, so the same question paired with
        # different answers reuses one search result.
        search_text = query.split("[CLS]")[0]
        if search_text in self.caches:
            return self.caches[search_text]

        result = self.tool.invoke(search_text)
        self.caches[search_text] = result
        return result

    async def _aget_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Async entry point; delegates to the synchronous implementation."""
        return self._get_relevant_documents(query)

    def save_cache(self, path):
        """Write the cache to ``path`` as JSON.

        Documents are stored as ``{"page_content", "metadata"}`` dicts
        because ``Document`` objects are not JSON serialisable.
        """
        payload = {
            search_text: [
                {"page_content": doc.page_content, "metadata": doc.metadata}
                for doc in docs
            ]
            for search_text, docs in self.caches.items()
        }
        with open(path, 'w', encoding="utf-8") as file:
            json.dump(payload, file, indent=4)

    def load_cache(self, path):
        """Replace the cache with the contents of a file written by ``save_cache``."""
        with open(path, 'r', encoding="utf-8") as file:
            payload = json.load(file)
        self.caches = {
            search_text: [
                Document(page_content=entry["page_content"], metadata=entry.get("metadata", {}))
                for entry in entries
            ]
            for search_text, entries in payload.items()
        }
|
|
32
|
+
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from langchain_classic.retrievers import EnsembleRetriever
|
|
3
|
+
from langchain_core.retrievers import BaseRetriever
|
|
4
|
+
from src.rag.retrievers.wrappers import TopKRetriever
|
|
5
|
+
from langchain_core.documents import Document
|
|
6
|
+
|
|
7
|
+
class HybridRetriever(BaseRetriever):
    """Combine several retrievers via a weighted ensemble, truncated by ``TopKRetriever``."""

    # Declared as a model field so it can be set on this pydantic model.
    retriever: BaseRetriever | None = None

    def __init__(self, retrievers: List, weights: List, **kwargs):
        """Build the ensemble.

        Args:
            retrievers: Retrievers to combine.
            weights: One weight per retriever, in the same order.
        """
        # TopKRetriever is a pydantic model, so its field is passed by keyword.
        kwargs.setdefault(
            "retriever",
            TopKRetriever(retriever=EnsembleRetriever(retrievers=retrievers, weights=weights)),
        )
        super().__init__(**kwargs)

    def _get_relevant_documents(self, query, *, run_manager=None) -> List[Document]:
        """Return the wrapped ensemble's results for ``query``."""
        return self.retriever.invoke(query)

    async def _aget_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Async entry point; delegates to the synchronous implementation."""
        return self._get_relevant_documents(query)
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from langchain_community.retrievers import BM25Retriever
|
|
2
|
+
from langchain_classic.retrievers import EnsembleRetriever
|
|
3
|
+
from langchain_community.vectorstores import FAISS
|
|
4
|
+
from langchain_core.documents import Document
|
|
5
|
+
from langchain_core.retrievers import BaseRetriever
|
|
6
|
+
from typing import List
|
|
7
|
+
from pydantic import Field
|
|
8
|
+
|
|
9
|
+
class LocalRetriever(BaseRetriever):
    """Retrieve example documents for a question from FAISS and BM25.

    Queries have the form ``"<question>[CLS]<answer>"``: the answer is the
    search text and the question restricts candidates to documents whose
    ``metadata["question"]`` matches.
    """

    vector_store: FAISS = Field(...)
    documents: List[Document] = Field(...)

    def vector_search(self, query: str, question: str, top_n: int = 3):
        """Return up to ``top_n`` FAISS hits for ``query`` within ``question``."""
        retriever = self.vector_store.as_retriever(
            search_kwargs={"k": top_n, "filter": {"question": question}}
        )
        return retriever.invoke(query)

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Split the query on the first ``[CLS]`` and run the ensemble search.

        Raises:
            ValueError: if ``query`` contains no ``[CLS]`` separator.
        """
        question, separator, answer = query.partition('[CLS]')
        if not separator:
            raise ValueError("query must have the form '<question>[CLS]<answer>'")
        return self.ensemble_search(answer, question)

    async def _aget_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Async entry point; delegates to the synchronous implementation."""
        return self._get_relevant_documents(query)

    def _build_ensemble(self, question: str, alpha: float, beta: float, top_n: int | None = None):
        """Create the BM25 + FAISS ensemble restricted to ``question``.

        ``alpha`` weights BM25 and ``beta`` weights FAISS. When no stored
        document matches the question, BM25 is left out because it cannot be
        built on an empty corpus.
        """
        search_kwargs = {"filter": {"question": question}}
        if top_n is not None:
            search_kwargs["k"] = top_n
        faiss_retriever = self.vector_store.as_retriever(search_kwargs=search_kwargs)

        candidates = [d for d in self.documents if d.metadata.get("question") == question]
        if not candidates:
            return EnsembleRetriever(retrievers=[faiss_retriever], weights=[beta])

        bm25_retriever = BM25Retriever.from_documents(candidates)
        return EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[alpha, beta]
        )

    def ensemble_search(self, query: str, question: str, alpha: float = 0.5, beta: float = 0.5):
        """Run the BM25 + FAISS ensemble on ``query`` and return its documents."""
        return self._build_ensemble(question, alpha, beta).invoke(query)

    def get_ensemble_retriever(self, question, top_n, alpha=0.5, beta=0.5):
        """Return the ensemble retriever for ``question`` without invoking it.

        ``top_n`` is passed to the FAISS retriever as ``k``. The
        sub-retrievers are kept local rather than stored on ``self``, since
        this pydantic model declares no fields for them.
        """
        return self._build_ensemble(question, alpha, beta, top_n=top_n)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .top_k import TopKRetriever
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from langchain_core.retrievers import BaseRetriever
|
|
2
|
+
from langchain_core.documents import Document
|
|
3
|
+
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
|
4
|
+
from typing import List, Any
|
|
5
|
+
|
|
6
|
+
class TopKRetriever(BaseRetriever):
    """Wrap any retriever and keep only its first ``top_k`` results (3 by default)."""

    # Wrapped retriever; only its ``invoke(query)`` method is used.
    retriever: Any
    # Maximum number of documents returned per query.
    top_k: int = 3

    class Config:
        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun = None
    ) -> List[Document]:
        """Return the leading ``top_k`` documents of the wrapped retriever.

        Args:
            query: Text passed unchanged to the wrapped retriever.
            run_manager: Callback manager supplied by the framework; unused.
        """
        return self.retriever.invoke(query)[:self.top_k]

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Any = None
    ) -> List[Document]:
        """Async entry point; delegates to the synchronous implementation.

        ``run_manager`` must be accepted here: because the synchronous hook
        declares it, the framework also passes it to this coroutine, and a
        signature without it fails with ``TypeError`` on ``ainvoke``.
        """
        return self._get_relevant_documents(query)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
class Splitter:
    """
    Split datasets into train and test (and optionally validation) parts.

    Rows are grouped by a key column and every group is shuffled with a fixed
    seed and cut separately, so each key contributes to every partition in
    roughly the configured proportions.
    """
    def __init__(self, test_size: float, temp_size: float=0):
        """
        Args:
            test_size: Fraction of each group that is held out of training.
            temp_size: When greater than 0, the held-out rows are divided
                again: this fraction of them becomes the test part and the
                remainder the validation part.
        """
        self.test_size = test_size
        self.temp_size = temp_size

    def split_many(self, df_list: List[pd.DataFrame], split_column):
        """
        Split several frames and concatenate the matching partitions.

        Returns:
            ``(train, val, test)`` when ``temp_size > 0``, otherwise
            ``(train, test)``; each is one DataFrame with a fresh index.
            An empty ``df_list`` yields empty DataFrames.
        """
        train_parts = []
        val_parts = []
        test_parts = []
        for df in df_list:
            parts = self.split(df, split_column)
            train_parts += parts[0]
            test_parts += parts[-1]
            if len(parts) > 2:
                val_parts += parts[1]

        if self.temp_size > 0:
            return (
                self._concat(train_parts),
                self._concat(val_parts),
                self._concat(test_parts),
            )
        return self._concat(train_parts), self._concat(test_parts)

    def split(self, data: pd.DataFrame, split_column: str):
        """
        Split one frame group by group.

        Returns:
            ``(train, val, test)`` when ``temp_size > 0``, otherwise
            ``(train, test)``; each is a list with one DataFrame per
            distinct value of ``split_column``, in order of first appearance.
        """
        with_validation = self.temp_size > 0
        train_df = []
        test_df = []
        val_df = []

        for key in data[split_column].unique():
            train_data, held_out = self._split_group(data[data[split_column] == key])
            train_df.append(train_data)
            if with_validation:
                n_test = int(self.temp_size * len(held_out))
                test_df.append(held_out.iloc[:n_test])
                val_df.append(held_out.iloc[n_test:])
            else:
                test_df.append(held_out)

        if with_validation:
            return train_df, val_df, test_df
        return train_df, test_df

    def _split_group(self, group: pd.DataFrame):
        """Shuffle one group and cut it into (train, held-out) frames."""
        shuffled = group.sample(frac=1, random_state=42).reset_index(drop=True)
        n_held_out = int(self.test_size * len(shuffled))
        n_held_out = max(0, min(n_held_out, len(shuffled)))
        # Cut at a non-negative position. The previous negative-index form
        # produced ``-0`` for small groups, which sent every row to the
        # held-out part and left the training part empty.
        cut = len(shuffled) - n_held_out
        return shuffled.iloc[:cut], shuffled.iloc[cut:]

    @staticmethod
    def _concat(frames):
        """Concatenate frames; an empty list gives an empty DataFrame."""
        if not frames:
            return pd.DataFrame()
        return pd.concat(frames, ignore_index=True)
|