llmflowstack 1.1.3__tar.gz → 1.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/PKG-INFO +1 -2
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/__init__.py +8 -8
- llmflowstack-1.1.3/llmflowstack/base/base.py → llmflowstack-1.2.0/llmflowstack/decoders/BaseDecoder.py +27 -56
- {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/GPT_OSS.py +14 -12
- {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/Gemma.py +12 -13
- {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/LLaMA3.py +9 -11
- {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/LLaMA4.py +15 -15
- {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/MedGemma.py +8 -10
- llmflowstack-1.2.0/llmflowstack/rag/VectorDatabase.py +278 -0
- llmflowstack-1.2.0/llmflowstack/rag/__init__.py +5 -0
- llmflowstack-1.2.0/llmflowstack/utils/logging.py +8 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/pyproject.toml +1 -2
- llmflowstack-1.1.3/llmflowstack/callbacks/__init__.py +0 -0
- llmflowstack-1.1.3/llmflowstack/rag/__iinit__.py +0 -5
- llmflowstack-1.1.3/llmflowstack/rag/pipeline.py +0 -279
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/.github/workflows/python-publish.yml +0 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/.gitignore +0 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/LICENSE +0 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/README.md +0 -0
- {llmflowstack-1.1.3/llmflowstack/base → llmflowstack-1.2.0/llmflowstack/callbacks}/__init__.py +0 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/callbacks/log_collector.py +0 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/callbacks/stop_on_token.py +0 -0
- {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/__init__.py +0 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/schemas/__init__.py +0 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/schemas/params.py +0 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/utils/__init__.py +0 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/utils/evaluation_methods.py +0 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/utils/exceptions.py +0 -0
- {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/utils/generation_utils.py +0 -0
{llmflowstack-1.1.3 → llmflowstack-1.2.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: llmflowstack
-Version: 1.1.3
+Version: 1.2.0
 Summary: LLMFlowStack is a framework for training and using LLMs (LLaMA, GPT-OSS, Gemma, ...). Supports DAPT, fine-tuning, and distributed inference. Public fork without institution-specific components.
 Author-email: Gustavo Henrique Ferreira Cruz <gustavohferreiracruz@gmail.com>
 License: MIT
@@ -10,7 +10,6 @@ Requires-Dist: accelerate
 Requires-Dist: bert-score
 Requires-Dist: bitsandbytes
 Requires-Dist: chromadb
-Requires-Dist: colorama
 Requires-Dist: datasets
 Requires-Dist: evaluate
 Requires-Dist: huggingface-hub
{llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/__init__.py
@@ -1,9 +1,9 @@
-from .models.Gemma import Gemma3
-from .models.GPT_OSS import GPT_OSS
-from .models.LLaMA3 import LLaMA3
-from .models.LLaMA4 import LLaMA4
-from .models.MedGemma import MedGemma
-from .rag.pipeline import RAGPipeline
+from .decoders.Gemma import Gemma3
+from .decoders.GPT_OSS import GPT_OSS
+from .decoders.LLaMA3 import LLaMA3
+from .decoders.LLaMA4 import LLaMA4
+from .decoders.MedGemma import MedGemma
+from .rag import VectorDatabase
 from .schemas.params import (GenerationBeamsParams, GenerationParams,
                              GenerationSampleParams, TrainParams)
 from .utils.evaluation_methods import text_evaluation
@@ -14,10 +14,10 @@ __all__ = [
 	"LLaMA3",
 	"LLaMA4",
 	"MedGemma",
-	"RAGPipeline",
 	"GenerationBeamsParams",
 	"GenerationParams",
 	"GenerationSampleParams",
 	"TrainParams",
-	"text_evaluation"
+	"text_evaluation",
+	"VectorDatabase"
 ]
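With the `models` package renamed to `decoders` and `RAGPipeline` replaced by `VectorDatabase`, downstream imports change accordingly. A minimal sketch of 1.2.0-style usage, using only names exported above (the checkpoint strings are placeholders, not taken from the package):

from llmflowstack import LLaMA3, VectorDatabase

model = LLaMA3(checkpoint="path/to/llama3-checkpoint", seed=42)   # placeholder checkpoint
db = VectorDatabase(checkpoint="path/to/embedding-checkpoint")    # placeholder checkpoint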
llmflowstack-1.1.3/llmflowstack/base/base.py → llmflowstack-1.2.0/llmflowstack/decoders/BaseDecoder.py
RENAMED
@@ -1,15 +1,14 @@
 import gc
 import json
-import logging
 import os
 import random
 from abc import ABC, abstractmethod
+from logging import getLogger
 from typing import Any, Literal, cast
 from uuid import uuid4

 import numpy as np
 import torch
-from colorama import Fore, Style, init
 from datasets import Dataset
 from torch import Tensor
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -20,15 +19,15 @@ from trl.trainer.sft_trainer import SFTTrainer
 from llmflowstack.callbacks.log_collector import LogCollectorCallback
 from llmflowstack.schemas.params import GenerationParams, TrainParams
 from llmflowstack.utils.exceptions import MissingEssentialProp
+from llmflowstack.utils.logging import LogLevel


-class BaseModel(ABC):
+class BaseDecoder(ABC):
 	model = None
 	tokenizer = None
 	_model_id = None
 	model_is_quantized = None
 	seed = None
-	log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO"
 	stop_token_ids = []
 	question_fields = []
 	answer_fields = []
@@ -37,20 +36,17 @@ class BaseModel(ABC):
 		self,
 		checkpoint: str | None = None,
 		quantization: Literal["4bit", "8bit"] | bool | None = None,
-		seed: int | None = None,
-		log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+		seed: int | None = None
 	) -> None:
 		if not self.question_fields or not self.answer_fields:
 			raise NotImplementedError("Subclasses must define question_fields and answer_fields.")

-		init(autoreset=True)
 		if seed:
 			self._set_seed(seed)

 		self._base_model = checkpoint

-		self._set_logger(log_level)
-		self.log_level = log_level
+		self.logger = getLogger(f"LLMFlowStack.{self.__class__.__name__}")

 		self.tokenizer: PreTrainedTokenizerBase | None = None

@@ -61,6 +57,17 @@ class BaseModel(ABC):
 			quantization=quantization
 		)

+	def _log(
+		self,
+		message: str,
+		level: LogLevel = LogLevel.INFO,
+	) -> None:
+		log_func = getattr(self.logger, level.lower(), None)
+		if log_func:
+			log_func(message)
+		else:
+			self.logger.info(message)
+
 	@abstractmethod
 	def _load_model(
 		self,
@@ -84,7 +91,7 @@ class BaseModel(ABC):
 		quantization: Any
 	) -> None:
 		if self.model:
-			self._log("A model is already loaded. Attempting to reset it.", "WARNING")
+			self._log("A model is already loaded. Attempting to reset it.", LogLevel.WARNING)
 			self.unload_model()

 		self._log(f"Loading model on '{checkpoint}'")
@@ -132,42 +139,6 @@ class BaseModel(ABC):
 	) -> None:
 		self._model_id = uuid4()

-	def _set_logger(
-		self,
-		level: str
-	) -> None:
-		level_map = {
-			"DEBUG": logging.DEBUG,
-			"INFO": logging.INFO,
-			"WARNING": logging.WARNING,
-			"ERROR": logging.ERROR,
-		}
-		numeric_level = level_map.get(level.upper(), logging.INFO)
-
-		logging.basicConfig(
-			level=numeric_level,
-			format="%(asctime)s - %(levelname)s - %(message)s"
-		)
-		self.logger = logging.getLogger(__name__)
-
-	def _log(
-		self,
-		info: str,
-		level: Literal["INFO", "WARNING", "ERROR", "DEBUG"] = "INFO"
-	) -> None:
-		if level == "INFO":
-			colored_msg = f"{Fore.GREEN}{info}{Style.RESET_ALL}"
-			self.logger.info(colored_msg)
-		elif level == "WARNING":
-			colored_msg = f"{Fore.YELLOW}{info}{Style.RESET_ALL}"
-			self.logger.warning(colored_msg)
-		elif level == "ERROR":
-			colored_msg = f"{Fore.RED}{info}{Style.RESET_ALL}"
-			self.logger.error(colored_msg)
-		elif level == "DEBUG":
-			colored_msg = f"{Fore.BLUE}{info}{Style.RESET_ALL}"
-			self.logger.debug(colored_msg)
-
 	def _set_seed(
 		self,
 		seed: int
@@ -190,10 +161,10 @@ class BaseModel(ABC):
 		path: str
 	) -> None:
 		if not self.model:
-			self._log("No model to save.", "WARNING")
+			self._log("No model to save.", LogLevel.WARNING)
 			return None
 		if not self.tokenizer:
-			self._log("No tokenizer to save.", "WARNING")
+			self._log("No tokenizer to save.", LogLevel.WARNING)
 			return None

 		os.makedirs(path, exist_ok=True)
@@ -299,16 +270,16 @@ class BaseModel(ABC):
 		save_path: str | None = None
 	) -> None:
 		if not self.model:
-			self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
+			self._log("Could not find a model loaded. Try loading a model first.", LogLevel.WARNING)
 			return None
 		if not self.tokenizer:
-			self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
+			self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", LogLevel.WARNING)
 			return None

 		self._log("Starting DAPT")

 		if self.model_is_quantized:
-			self._log("Cannot DAPT a quantized model.", "WARNING")
+			self._log("Cannot DAPT a quantized model.", LogLevel.WARNING)
 			return None

 		if params is None:
@@ -443,16 +414,16 @@ class BaseModel(ABC):
 		save_path: str | None = None
 	) -> None:
 		if not self.model:
-			self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
+			self._log("Could not find a model loaded. Try loading a model first.", LogLevel.WARNING)
 			return None
 		if not self.tokenizer:
-			self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
+			self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", LogLevel.WARNING)
 			return None

 		self._log("Starting fine-tune")

 		if self.model_is_quantized:
-			self._log("Cannot fine-tune a quantized model.", "WARNING")
+			self._log("Cannot fine-tune a quantized model.", LogLevel.WARNING)
 			return None

 		if params is None:
@@ -521,8 +492,8 @@ class BaseModel(ABC):
 			self._model_id = None
 			self._log("Reset successfully.")
 		except Exception as e:
-			self._log("Couldn't reset model...", "ERROR")
-			self._log(f"{str(e)}", "DEBUG")
+			self._log("Couldn't reset model...", LogLevel.ERROR)
+			self._log(f"{str(e)}", LogLevel.DEBUG)

 	def set_seed(self, seed: int) -> None:
 		self._log(f"Setting seed value {seed}")
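The new `_log` helper resolves `level.lower()` against a standard `logging.Logger`, and the release adds `llmflowstack/utils/logging.py` (+8 lines, not expanded in this diff) exporting `LogLevel` with at least INFO, WARNING, ERROR, and DEBUG members. A sketch of one way that module could look — the enum body is an assumption; only the `LogLevel` name and its members are attested by the diff (a string-valued enum is assumed so that `level.lower()` yields the logger method name):

from enum import Enum

class LogLevel(str, Enum):
	INFO = "INFO"
	WARNING = "WARNING"
	ERROR = "ERROR"
	DEBUG = "DEBUG"

Because BaseDecoder now uses getLogger("LLMFlowStack.<ClassName>") instead of colorama-colored basicConfig output, callers can control verbosity with standard logging configuration, for example:

import logging

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logging.getLogger("LLMFlowStack").setLevel(logging.DEBUG)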
{llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/GPT_OSS.py
RENAMED
@@ -1,4 +1,3 @@
-import textwrap
 import threading
 from functools import partial
 from time import time
@@ -11,11 +10,12 @@ from transformers import (AutoTokenizer, StoppingCriteriaList,
 from transformers.models.gpt_oss import GptOssForCausalLM
 from transformers.utils.quantization_config import Mxfp4Config

-from llmflowstack.base.base import BaseModel
 from llmflowstack.callbacks.stop_on_token import StopOnToken
+from llmflowstack.decoders.BaseDecoder import BaseDecoder
 from llmflowstack.schemas.params import GenerationParams
 from llmflowstack.utils.exceptions import MissingEssentialProp
 from llmflowstack.utils.generation_utils import create_generation_params
+from llmflowstack.utils.logging import LogLevel


 class GPTOSSInput(TypedDict):
@@ -26,7 +26,7 @@ class GPTOSSInput(TypedDict):
 	reasoning_message: str | None
 	reasoning_level: Literal["Low", "Medium", "High"] | None

-class GPT_OSS(BaseModel):
+class GPT_OSS(BaseDecoder):
 	model: GptOssForCausalLM | None = None
 	reasoning_level: Literal["Low", "Medium", "High"] = "Low"
 	question_fields = ["input_text", "developer_message", "system_message"]
@@ -36,14 +36,12 @@ class GPT_OSS(BaseModel):
 		self,
 		checkpoint: str | None = None,
 		quantization: bool | None = None,
-		seed: int | None = None,
-		log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+		seed: int | None = None
 	) -> None:
 		return super().__init__(
 			checkpoint=checkpoint,
 			quantization=quantization,
-			seed=seed,
-			log_level=log_level
+			seed=seed
 		)

 	def _set_generation_stopping_tokens(
@@ -51,7 +49,7 @@ class GPT_OSS(BaseModel):
 		tokens: list[int]
 	) -> None:
 		if not self.tokenizer:
-			self._log("Could not set stop tokens - generation may not work...", "WARNING")
+			self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
 			return None
 		encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
 		particular_tokens = encoding.stop_tokens_for_assistant_actions()
@@ -76,7 +74,7 @@ class GPT_OSS(BaseModel):
 				attn_implementation="eager",
 			)
 		except Exception as _:
-			self._log("Error trying to load the model. Defaulting to load without quantization...", "WARNING")
+			self._log("Error trying to load the model. Defaulting to load without quantization...", LogLevel.WARNING)
 			self.model = GptOssForCausalLM.from_pretrained(
 				checkpoint,
 				dtype="auto",
@@ -119,7 +117,11 @@ class GPT_OSS(BaseModel):
 		if expected_answer:
 			assistant_text += f"<|start|>assistant<|channel|>final<|message|>{expected_answer}<|return|>"

-		return
+		return (
+			f"{system_text}{developer_text}"
+			f"<|start|>user<|message|>{data["input_text"]}<|end|>"
+			f"{assistant_text}"
+		)

 	def build_input(
 		self,
@@ -154,7 +156,7 @@ class GPT_OSS(BaseModel):
 		params: GenerationParams | None = None
 	) -> str | None:
 		if self.model is None or self.tokenizer is None:
-			self._log("Model or Tokenizer missing", "WARNING")
+			self._log("Model or Tokenizer missing", LogLevel.WARNING)
 			return None

 		self._log(f"Processing received input...'")
@@ -222,7 +224,7 @@ class GPT_OSS(BaseModel):
 		params: GenerationParams | None = None
 	) -> Iterator[str]:
 		if self.model is None or self.tokenizer is None:
-			self._log("Model or Tokenizer missing", "WARNING")
+			self._log("Model or Tokenizer missing", LogLevel.WARNING)
 		if False:
 			yield ""
 		return
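The rebuilt return statement above assembles a Harmony-format prompt from the precomputed `system_text`, `developer_text`, and `assistant_text` pieces. For a hypothetical input, the assembled string has roughly the following shape (every value below is illustrative; the code that builds these variables sits outside the hunks shown in this diff):

# Illustrative only: mirrors the f-strings added in the hunk above.
system_text = "<|start|>system<|message|>You are a concise assistant.<|end|>"  # placeholder
developer_text = ""                                                            # placeholder
assistant_text = ""                                                            # placeholder
input_text = "What is DAPT?"                                                   # placeholder
prompt = (
	f"{system_text}{developer_text}"
	f"<|start|>user<|message|>{input_text}<|end|>"
	f"{assistant_text}"
)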
{llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/Gemma.py
RENAMED
@@ -10,12 +10,13 @@ from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
 from transformers.models.gemma3 import Gemma3ForCausalLM
 from transformers.utils.quantization_config import BitsAndBytesConfig

-from llmflowstack.base.base import BaseModel
 from llmflowstack.callbacks.log_collector import LogCollectorCallback
 from llmflowstack.callbacks.stop_on_token import StopOnToken
+from llmflowstack.decoders.BaseDecoder import BaseDecoder
 from llmflowstack.schemas.params import GenerationParams, TrainParams
 from llmflowstack.utils.exceptions import MissingEssentialProp
 from llmflowstack.utils.generation_utils import create_generation_params
+from llmflowstack.utils.logging import LogLevel


 class Gemma3Input(TypedDict):
@@ -24,7 +25,7 @@ class Gemma3Input(TypedDict):
 	system_message: str | None
 	image_paths: list[str] | None

-class Gemma3(BaseModel):
+class Gemma3(BaseDecoder):
 	model: Gemma3ForCausalLM | None = None
 	question_fields = ["input_text", "system_message"]
 	answer_fields = ["expected_answer"]
@@ -33,14 +34,12 @@ class Gemma3(BaseModel):
 		self,
 		checkpoint: str | None = None,
 		quantization: Literal["4bit"] | None = None,
-		seed: int | None = None,
-		log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+		seed: int | None = None
 	) -> None:
 		return super().__init__(
 			checkpoint=checkpoint,
 			quantization=quantization,
-			seed=seed,
-			log_level=log_level
+			seed=seed
 		)

 	def _set_generation_stopping_tokens(
@@ -48,7 +47,7 @@ class Gemma3(BaseModel):
 		tokens: list[int]
 	) -> None:
 		if not self.tokenizer:
-			self._log("Could not set stop tokens - generation may not work...", "WARNING")
+			self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
 			return None
 		particular_tokens = self.tokenizer.encode("<end_of_turn>")
 		self.stop_token_ids = tokens + particular_tokens
@@ -129,16 +128,16 @@ class Gemma3(BaseModel):
 		save_path: str | None = None
 	) -> None:
 		if not self.model:
-			self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
+			self._log("Could not find a model loaded. Try loading a model first.", LogLevel.WARNING)
 			return None
 		if not self.tokenizer:
-			self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
+			self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", LogLevel.WARNING)
 			return None

 		self._log("Starting Training")

 		if self.model_is_quantized:
-			self._log("Cannot traub a quantized model.", "WARNING")
+			self._log("Cannot traub a quantized model.", LogLevel.WARNING)
 			return None

 		if params is None:
@@ -195,7 +194,7 @@ class Gemma3(BaseModel):
 		save_at_end = True,
 		save_path: str | None = None
 	) -> None:
-		self._log("Only 'dapt' method is available for this class. Redirecting call to it.", "WARNING")
+		self._log("Only 'dapt' method is available for this class. Redirecting call to it.", LogLevel.WARNING)
 		return self.dapt(
 			train_dataset=train_dataset,
 			params=params,
@@ -210,7 +209,7 @@ class Gemma3(BaseModel):
 		params: GenerationParams | None = None,
 	) -> str | None:
 		if self.model is None or self.tokenizer is None:
-			self._log("Model or Tokenizer missing", "WARNING")
+			self._log("Model or Tokenizer missing", LogLevel.WARNING)
 			return None

 		self._log(f"Processing received input...'")
@@ -267,7 +266,7 @@ class Gemma3(BaseModel):
 		params: GenerationParams | None = None
 	) -> Iterator[str]:
 		if self.model is None or self.tokenizer is None:
-			self._log("Model or Tokenizer missing", "WARNING")
+			self._log("Model or Tokenizer missing", LogLevel.WARNING)
 		if False:
 			yield ""
 		return
{llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/LLaMA3.py
RENAMED
@@ -1,4 +1,3 @@
-import textwrap
 import threading
 from time import time
 from typing import Iterator, Literal, TypedDict, cast
@@ -9,11 +8,12 @@ from transformers import (AutoTokenizer, StoppingCriteriaList,
 from transformers.models.llama import LlamaForCausalLM
 from transformers.utils.quantization_config import BitsAndBytesConfig

-from llmflowstack.base.base import BaseModel
 from llmflowstack.callbacks.stop_on_token import StopOnToken
+from llmflowstack.decoders.BaseDecoder import BaseDecoder
 from llmflowstack.schemas.params import GenerationParams
 from llmflowstack.utils.exceptions import MissingEssentialProp
 from llmflowstack.utils.generation_utils import create_generation_params
+from llmflowstack.utils.logging import LogLevel


 class LLaMA3Input(TypedDict):
@@ -21,7 +21,7 @@ class LLaMA3Input(TypedDict):
 	expected_answer: str | None
 	system_message: str | None

-class LLaMA3(BaseModel):
+class LLaMA3(BaseDecoder):
 	model: LlamaForCausalLM | None = None
 	question_fields = ["input_text", "system_message"]
 	answer_fields = ["expected_answer"]
@@ -30,14 +30,12 @@ class LLaMA3(BaseModel):
 		self,
 		checkpoint: str | None = None,
 		quantization: Literal["4bit", "8bit"] | None = None,
-		seed: int | None = None,
-		log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+		seed: int | None = None
 	) -> None:
 		return super().__init__(
 			checkpoint=checkpoint,
 			quantization=quantization,
-			seed=seed,
-			log_level=log_level
+			seed=seed
 		)

 	def _set_generation_stopping_tokens(
@@ -45,7 +43,7 @@ class LLaMA3(BaseModel):
 		tokens: list[int]
 	) -> None:
 		if not self.tokenizer:
-			self._log("Could not set stop tokens - generation may not work...", "WARNING")
+			self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
 			return None
 		particular_tokens = self.tokenizer.encode("<|eot_id|>")
 		self.stop_token_ids = tokens + particular_tokens
@@ -92,7 +90,7 @@ class LLaMA3(BaseModel):

 		system_message = data.get("system_message", "")

-		return
+		return (
 			f"<|start_header_id|>system<|end_header_id|>{system_message}\n"
 			f"<|eot_id|><|start_header_id|>user<|end_header_id|>{data["input_text"]}\n"
 			f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>{answer}"
@@ -119,7 +117,7 @@ class LLaMA3(BaseModel):
 		params: GenerationParams | None = None
 	) -> str | None:
 		if self.model is None or self.tokenizer is None:
-			self._log("Model or Tokenizer missing", "WARNING")
+			self._log("Model or Tokenizer missing", LogLevel.WARNING)
 			return None

 		self.model
@@ -184,7 +182,7 @@ class LLaMA3(BaseModel):
 		params: GenerationParams | None = None
 	) -> Iterator[str]:
 		if self.model is None or self.tokenizer is None:
-			self._log("Model or Tokenizer missing", "WARNING")
+			self._log("Model or Tokenizer missing", LogLevel.WARNING)
 		if False:
 			yield ""
 		return
{llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/LLaMA4.py
RENAMED
@@ -1,7 +1,7 @@
 import threading
 from functools import partial
 from time import time
-from typing import Iterator, Literal, TypedDict, cast
+from typing import Iterator, TypedDict, cast

 import torch
 from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
@@ -9,12 +9,13 @@ from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
                           TrainingArguments)
 from transformers.models.llama4 import Llama4ForCausalLM

-from llmflowstack.base.base import BaseModel
 from llmflowstack.callbacks.log_collector import LogCollectorCallback
 from llmflowstack.callbacks.stop_on_token import StopOnToken
+from llmflowstack.decoders.BaseDecoder import BaseDecoder
 from llmflowstack.schemas.params import GenerationParams, TrainParams
 from llmflowstack.utils.exceptions import MissingEssentialProp
 from llmflowstack.utils.generation_utils import create_generation_params
+from llmflowstack.utils.logging import LogLevel


 class LLaMA4Input(TypedDict):
@@ -22,7 +23,7 @@ class LLaMA4Input(TypedDict):
 	expected_answer: str | None
 	system_message: str | None

-class LLaMA4(BaseModel):
+class LLaMA4(BaseDecoder):
 	model: Llama4ForCausalLM | None = None
 	question_fields = ["input_text", "system_message"]
 	answer_fields = ["expected_answer"]
@@ -30,14 +31,12 @@ class LLaMA4(BaseModel):
 	def __init__(
 		self,
 		checkpoint: str | None = None,
-		seed: int | None = None,
-		log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+		seed: int | None = None
 	) -> None:
 		return super().__init__(
 			checkpoint=checkpoint,
 			quantization=None,
-			seed=seed,
-			log_level=log_level
+			seed=seed
 		)

 	def _set_generation_stopping_tokens(
@@ -45,7 +44,7 @@ class LLaMA4(BaseModel):
 		tokens: list[int]
 	) -> None:
 		if not self.tokenizer:
-			self._log("Could not set stop tokens - generation may not work...", "WARNING")
+			self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
 			return None
 		particular_tokens = self.tokenizer.encode("<|eot|>")
 		self.stop_token_ids = tokens + particular_tokens
@@ -84,7 +83,8 @@ class LLaMA4(BaseModel):
 		system_message = f"<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>"

 		expected_answer = data.get("expected_answer")
-		answer =
+		answer = "<|header_start|>assistant<|header_end|>\n\n"
+		answer += f"{expected_answer}<|eot|>" if expected_answer else ""

 		return (
 			"<|begin_of_text|>"
@@ -118,16 +118,16 @@ class LLaMA4(BaseModel):
 		save_path: str | None = None
 	) -> None:
 		if not self.model:
-			self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
+			self._log("Could not find a model loaded. Try loading a model first.", LogLevel.WARNING)
 			return None
 		if not self.tokenizer:
-			self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
+			self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", LogLevel.WARNING)
 			return None

 		self._log("Starting DAPT")

 		if self.model_is_quantized:
-			self._log("Cannot DAPT a quantized model.", "WARNING")
+			self._log("Cannot DAPT a quantized model.", LogLevel.WARNING)
 			return None

 		if params is None:
@@ -184,7 +184,7 @@ class LLaMA4(BaseModel):
 		save_at_end = True,
 		save_path: str | None = None
 	) -> None:
-		self._log("Only 'dapt' method is available for this class. Redirecting call to it.", "WARNING")
+		self._log("Only 'dapt' method is available for this class. Redirecting call to it.", LogLevel.WARNING)
 		return self.dapt(
 			train_dataset=train_dataset,
 			params=params,
@@ -199,7 +199,7 @@ class LLaMA4(BaseModel):
 		params: GenerationParams | None = None
 	) -> str | None:
 		if self.model is None or self.tokenizer is None:
-			self._log("Model or Tokenizer missing", "WARNING")
+			self._log("Model or Tokenizer missing", LogLevel.WARNING)
 			return None

 		self.model
@@ -263,7 +263,7 @@ class LLaMA4(BaseModel):
 		params: GenerationParams | None = None
 	) -> Iterator[str]:
 		if self.model is None or self.tokenizer is None:
-			self._log("Model or Tokenizer missing", "WARNING")
+			self._log("Model or Tokenizer missing", LogLevel.WARNING)
 		if False:
 			yield ""
 		return
{llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/MedGemma.py
RENAMED
@@ -1,4 +1,3 @@
-import textwrap
 import threading
 from functools import partial
 from time import time
@@ -10,11 +9,12 @@ from transformers import (AutoTokenizer, StoppingCriteriaList,
 from transformers.models.gemma3 import Gemma3ForCausalLM
 from transformers.utils.quantization_config import BitsAndBytesConfig

-from llmflowstack.base.base import BaseModel
 from llmflowstack.callbacks.stop_on_token import StopOnToken
+from llmflowstack.decoders.BaseDecoder import BaseDecoder
 from llmflowstack.schemas.params import GenerationParams
 from llmflowstack.utils.exceptions import MissingEssentialProp
 from llmflowstack.utils.generation_utils import create_generation_params
+from llmflowstack.utils.logging import LogLevel


 class MedGemmaInput(TypedDict):
@@ -22,7 +22,7 @@ class MedGemmaInput(TypedDict):
 	expected_answer: str | None
 	system_message: str | None

-class MedGemma(BaseModel):
+class MedGemma(BaseDecoder):
 	model: Gemma3ForCausalLM | None = None
 	can_think = False
 	question_fields = ["input_text", "system_message"]
@@ -32,14 +32,12 @@ class MedGemma(BaseModel):
 		self,
 		checkpoint: str | None = None,
 		quantization: Literal["4bit"] | None = None,
-		seed: int | None = None,
-		log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+		seed: int | None = None
 	) -> None:
 		return super().__init__(
 			checkpoint=checkpoint,
 			quantization=quantization,
-			seed=seed,
-			log_level=log_level
+			seed=seed
 		)

 	def _set_generation_stopping_tokens(
@@ -47,7 +45,7 @@ class MedGemma(BaseModel):
 		tokens: list[int]
 	) -> None:
 		if not self.tokenizer:
-			self._log("Could not set stop tokens - generation may not work...", "WARNING")
+			self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
 			return None
 		particular_tokens = self.tokenizer.encode("<end_of_turn>")
 		self.stop_token_ids = tokens + particular_tokens
@@ -128,7 +126,7 @@ class MedGemma(BaseModel):
 		params: GenerationParams | None = None,
 	) -> str | None:
 		if self.model is None or self.tokenizer is None:
-			self._log("Model or Tokenizer missing", "WARNING")
+			self._log("Model or Tokenizer missing", LogLevel.WARNING)
 			return None

 		self._log(f"Processing received input...'")
@@ -196,7 +194,7 @@ class MedGemma(BaseModel):
 		params: GenerationParams | None = None
 	) -> Iterator[str]:
 		if self.model is None or self.tokenizer is None:
-			self._log("Model or Tokenizer missing", "WARNING")
+			self._log("Model or Tokenizer missing", LogLevel.WARNING)
 		if False:
 			yield ""
 		return
llmflowstack-1.2.0/llmflowstack/rag/VectorDatabase.py
@@ -0,0 +1,278 @@
import gc
import uuid
from logging import getLogger

import chromadb
import chromadb.config
import torch
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer

from llmflowstack.utils.exceptions import MissingEssentialProp
from llmflowstack.utils.logging import LogLevel


class EncoderWrapper(Embeddings):
	def __init__(
		self,
		model: SentenceTransformer
	) -> None:
		self.model = model

	def embed_documents(
		self,
		texts: list[str]
	) -> list[list[float]]:
		vectors = self.model.encode(texts, task="retrieval", show_progress_bar=False)
		return vectors.tolist()

	def embed_query(
		self,
		text: str
	) -> list[float]:
		vectors = self.model.encode(text, task="retrieval", show_progress_bar=False)
		return vectors.tolist()

class VectorDatabase:
	encoder: SentenceTransformer | None = None
	collections: dict[str, Chroma] = {}

	def __init__(
		self,
		checkpoint: str | None = None,
		chunk_size: int = 1000,
		chunk_overlap: int = 200
	) -> None:
		self.logger = getLogger(f"LLMFlowStack.{self.__class__.__name__}")

		self.encoder = None
		if checkpoint:
			self.load_encoder(
				checkpoint=checkpoint
			)

		self.splitter = RecursiveCharacterTextSplitter(
			chunk_size=chunk_size,
			chunk_overlap=chunk_overlap,
			add_start_index=True,
		)

	def _log(
		self,
		message: str,
		level: LogLevel = LogLevel.INFO,
	) -> None:
		log_func = getattr(self.logger, level.lower(), None)
		if log_func:
			log_func(message)
		else:
			self.logger.info(message)

	def load_encoder(
		self,
		checkpoint: str
	) -> None:
		if self.encoder:
			self._log("A encoder is already loaded. Attempting to reset it.", LogLevel.WARNING)
			self.unload_encoder()

		self._log(f"Loading encoder on '{checkpoint}'")
		self.encoder = SentenceTransformer(
			checkpoint,
			trust_remote_code=True
		)

		self._log("Encoder loaded")

	def unload_encoder(
		self
	) -> None:
		try:
			del self.encoder
			gc.collect()
			torch.cuda.empty_cache()
			self.encoder = None
			self._log("Reset successfully.")
		except Exception as e:
			self._log("Couldn't reset encoder...", LogLevel.ERROR)
			self._log(f"{str(e)}", LogLevel.DEBUG)

	def get_collection(
		self,
		collection_name: str = "rag_memory",
		persist_directory: str | None = None
	) -> None:
		if not self.encoder:
			raise MissingEssentialProp("Could not find encoder.")

		client_settings = chromadb.config.Settings(
			anonymized_telemetry=False
		)

		self.collections[collection_name] = Chroma(
			collection_name=collection_name,
			embedding_function=EncoderWrapper(self.encoder),
			persist_directory=persist_directory,
			client_settings=client_settings
		)

	def validate_collection_name(
		self,
		collection_name: str
	) -> None:
		if collection_name not in self.collections:
			raise ValueError("Collection name not found in collection")

	def index_documents(
		self,
		collection_name: str,
		docs: list[Document],
		ids: list[str],
		can_split: bool = True,
	) -> None:
		self.validate_collection_name(
			collection_name=collection_name
		)

		for doc, src_id in zip(docs, ids):
			if doc.metadata is None:
				doc.metadata = {}
			doc.metadata["source_id"] = src_id

		if can_split:
			splits = self.splitter.split_documents(docs)
		else:
			splits = docs

		split_ids = []
		metadatas = []
		texts = []

		for i, s in enumerate(splits):
			src = s.metadata.get("source_id", "unknown")
			sid = f"{src}_{i}"
			split_ids.append(sid)
			metadatas.append(s.metadata.copy())
			texts.append(s.page_content)

		self.collections[collection_name].add_texts(
			texts=texts,
			ids=split_ids,
			metadatas=metadatas
		)

	def create(
		self,
		collection_name: str,
		information: str,
		other_info: dict[str, str] | None = None,
		doc_id: str | None = None,
		should_index: bool = True,
		can_split: bool = True
	) -> Document:
		if other_info is None:
			other_info = {}

		if doc_id is None:
			doc_id = str(uuid.uuid4())

		metadata = {"source_id": doc_id, **other_info}
		doc = Document(
			page_content=information,
			metadata=metadata
		)

		if should_index:
			self.index_documents(
				collection_name=collection_name,
				docs=[doc],
				ids=[doc_id],
				can_split=can_split
			)

		return doc

	def update(
		self,
		collection_name: str,
		doc_id: str,
		new_information: str,
		other_info: dict[str, str] | None = None
	) -> Document:
		self.validate_collection_name(
			collection_name=collection_name
		)

		if other_info is None:
			other_info = {}

		documents_to_delete = self.collections[collection_name].get(
			where={
				"source_id": doc_id
			}
		)

		ids_to_delete = documents_to_delete.get("ids", [])

		if ids_to_delete:
			self.collections[collection_name].delete(ids=ids_to_delete)

		return self.create(
			collection_name=collection_name,
			information=new_information,
			other_info=other_info,
			doc_id=doc_id
		)

	def delete(
		self,
		collection_name: str,
		doc_id: str
	) -> None:
		self.validate_collection_name(
			collection_name=collection_name
		)

		self.collections[collection_name].delete(ids=[doc_id])

	def rquery(
		self,
		collection_name: str,
		query: str,
		k: int = 4,
		filter: dict | None = None
	) -> list[Document]:
		self.validate_collection_name(
			collection_name=collection_name
		)

		return self.collections[collection_name].similarity_search(
			query=query,
			k=k,
			filter=filter
		)

	def query(
		self,
		collection_name: str,
		query: str,
		k: int = 4,
		filter: dict | None = None
	) -> str:
		self.validate_collection_name(
			collection_name=collection_name
		)

		if filter:
			docs = self.collections[collection_name].similarity_search(
				query=query,
				k=k,
				filter=filter
			)
		else:
			docs = self.collections[collection_name].similarity_search(query, k=k)

		return "\n\n".join(doc.page_content for doc in docs)
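A minimal usage sketch of the new VectorDatabase, using only methods defined above. The encoder checkpoint, collection name, and texts are placeholders; the checkpoint must be a SentenceTransformer model that accepts the `task="retrieval"` argument that EncoderWrapper passes to `encode`:

from llmflowstack import VectorDatabase

db = VectorDatabase(checkpoint="my-embedding-checkpoint")   # placeholder checkpoint
db.get_collection(collection_name="papers")                 # in-memory (no persist_directory)

db.create(
	collection_name="papers",
	information="Domain-adaptive pretraining continues pretraining on in-domain text.",
	other_info={"topic": "dapt"},
)
context = db.query(collection_name="papers", query="What is DAPT?", k=2)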
{llmflowstack-1.1.3 → llmflowstack-1.2.0}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "llmflowstack"
-version = "1.1.3"
+version = "1.2.0"
 authors = [
 	{ name = "Gustavo Henrique Ferreira Cruz", email = "gustavohferreiracruz@gmail.com" }
 ]
@@ -17,7 +17,6 @@ dependencies = [
 	"accelerate",
 	"bert-score",
 	"bitsandbytes",
-	"colorama",
 	"chromadb",
 	"datasets",
 	"evaluate",
llmflowstack-1.1.3/llmflowstack/rag/pipeline.py
@@ -1,279 +0,0 @@
import uuid

import chromadb
import chromadb.config
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer


class EncoderWrapper(Embeddings):
	def __init__(
		self,
		model: SentenceTransformer
	) -> None:
		self.model = model

	def embed_documents(
		self,
		texts: list[str]
	) -> list[list[float]]:
		vectors = self.model.encode(texts, task="retrieval", show_progress_bar=False)
		return vectors.tolist()

	def embed_query(
		self,
		text: str
	) -> list[float]:
		vectors = self.model.encode(text, task="retrieval", show_progress_bar=False)
		return vectors.tolist()

class RAGPipeline:
	"""
	A modular Retrieval-Augmented Generation (RAG) pipeline for embedding, indexing, and retrieving scientific or textual data using SentenceTransformers and Chroma as a vector store.

	Supports both persistent (disk-based) and transient (in-memory) modes depending on whether `persist_directory` is provided.
	"""
	def __init__(
		self,
		checkpoint: str,
		collection_name: str = "rag_memory",
		persist_directory: str | None = None,
		chunk_size: int = 1000,
		chunk_overlap: int = 200
	) -> None:
		"""
		Initializes the RAG pipeline.

		Args:
			checkpoint (str): Path or name of the SentenceTransformer checkpoint.
			collection_name (str): Name of the Chroma collection to create or load.
			persist_directory (str | None): Directory where the vector database is stored. If None, all data is kept in-memory and discarded after the session ends.
			chunk_size (int): Maximum size (in characters) for text chunks during indexing.
			chunk_overlap (int): Overlap (in characters) between consecutive text chunks.
		"""
		self.encoder = SentenceTransformer(checkpoint, trust_remote_code=True)

		client_settings = chromadb.config.Settings(
			anonymized_telemetry=False
		)

		self.collection = Chroma(
			collection_name=collection_name,
			embedding_function=EncoderWrapper(self.encoder),
			persist_directory=persist_directory,
			client_settings=client_settings
		)

		self.splitter = RecursiveCharacterTextSplitter(
			chunk_size=chunk_size,
			chunk_overlap=chunk_overlap,
			add_start_index=True,
		)

	def index_documents(
		self,
		docs: list[Document],
		ids: list[str],
		can_split: bool = True
	) -> None:
		"""
		Indexes a list of documents into the Chroma vector store.

		Each document is assigned a unique `source_id` and, optionally, split into smaller chunks for more granular retrieval. Each resulting chunk is embedded and stored with its metadata for later similarity search.

		Args:
			docs (list[Document]): List of LangChain `Document` objects to index.
			ids (list[str]): Unique identifiers corresponding to each document.
			can_split (bool): Whether to split documents into smaller chunks before
				indexing. Set to False to index each document as a single entry
				(e.g., for short or self-contained texts).

		Returns:
			None
		"""
		for doc, src_id in zip(docs, ids):
			if doc.metadata is None:
				doc.metadata = {}
			doc.metadata["source_id"] = src_id

		if can_split:
			splits = self.splitter.split_documents(docs)
		else:
			splits = docs

		split_ids = []
		metadatas = []
		texts = []

		for i, s in enumerate(splits):
			src = s.metadata.get("source_id", "unknown")
			sid = f"{src}_{i}"
			split_ids.append(sid)
			metadatas.append(s.metadata.copy())
			texts.append(s.page_content)

		self.collection.add_texts(
			texts=texts,
			ids=split_ids,
			metadatas=metadatas
		)

	def create(
		self,
		information: str,
		other_info: dict[str, str] | None = None,
		doc_id: str | None = None,
		should_index: bool = True,
		can_split: bool = True
	) -> Document:
		"""
		Creates a new `Document` and optionally indexes it in the collection.

		This is a convenience method that wraps both document creation and embedding/indexing in one step. Metadata fields are merged into the document and can include any descriptive information (e.g., title, DOI, year).

		Args:
			information (str): Main textual content of the document.
			other_info (dict[str, str] | None): Optional metadata fields to include.
			doc_id (str | None): Custom document identifier. If None, a UUID is generated.
			should_index (bool): Whether to immediately add the document to the vector store.
			can_split (bool): Whether to allow splitting before indexing.

		Returns:
			Document: The created LangChain `Document` object (indexed if specified).
		"""
		if other_info is None:
			other_info = {}

		if doc_id is None:
			doc_id = str(uuid.uuid4())

		metadata = {"source_id": doc_id, **other_info}
		doc = Document(
			page_content=information,
			metadata=metadata
		)

		if should_index:
			self.index_documents(
				docs=[doc],
				ids=[doc_id],
				can_split=can_split
			)

		return doc

	def update(
		self,
		doc_id: str,
		new_information: str,
		other_info: dict[str, str] | None = None
	) -> Document:
		"""
		Updates an existing document in the collection with new content and metadata.

		All vector entries associated with the provided `doc_id` are deleted, and a new document is created and re-indexed in their place. This ensures that embeddings remain consistent with the latest text content.

		Args:
			doc_id (str): Identifier of the document to update.
			new_information (str): Updated text content for the document.
			other_info (dict[str, str] | None): Optional new metadata to associate.

		Returns:
			Document: The newly created (updated) `Document` object.
		"""
		if other_info is None:
			other_info = {}

		documents_to_delete = self.collection.get(
			where={
				"source_id": doc_id
			}
		)

		ids_to_delete = documents_to_delete.get("ids", [])

		if ids_to_delete:
			self.collection.delete(ids=ids_to_delete)

		return self.create(
			information=new_information,
			other_info=other_info,
			doc_id=doc_id
		)

	def delete(
		self,
		doc_id: str
	) -> None:
		"""
		Deletes all indexed entries associated with a specific document ID.

		Removes all vectors and metadata tied to the provided `doc_id` from the collection. Use this to completely erase a document's content from the indexed database.

		Args:
			doc_id (str): Identifier of the document to delete.

		Returns:
			None
		"""
		self.collection.delete(ids=[doc_id])

	def rquery(
		self,
		query: str,
		k: int = 4,
		filter: dict | None = None
	) -> list[Document]:
		"""
		Perform a **raw semantic search** on the collection.

		This method queries the vector store using the provided text query and returns the top-`k` most similar `Document` objects, optionally filtered by metadata.

		Args:
			query (str): The natural-language query text to embed and search for.
			k (int, optional): Number of top results to return. Defaults to 4.
			filter (dict | None, optional): Metadata filter applied during search
				(e.g., {"type": "article"}). Defaults to None.

		Returns:
			list[Document]: A list of matching documents sorted by similarity score.
		"""
		return self.collection.similarity_search(
			query=query,
			k=k,
			filter=filter
		)

	def query(
		self,
		query: str,
		k: int = 4,
		filter: dict | None = None
	) -> str:
		"""
		Perform a **semantic search** and return the combined text content.

		This method wraps `rquery()` and concatenates the retrieved document contents into a single string, suitable for direct use in downstream LLM prompts or text processing.

		Args:
			query (str): The natural-language query text to search for.
			k (int, optional): Number of top results to return. Defaults to 4.
			filter (dict | None, optional): Metadata filter applied during search. If None, all documents are considered.

		Returns:
			str: A newline-separated string containing the page contents of
				the retrieved documents.
		"""
		if filter:
			docs = self.collection.similarity_search(
				query=query,
				k=k,
				filter=filter
			)
		else:
			docs = self.collection.similarity_search(query, k=k)

		return "\n\n".join(doc.page_content for doc in docs)
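The removed RAGPipeline and the new VectorDatabase cover the same create/update/delete/query operations; the main difference is that the Chroma collection is now created explicitly and named on every call instead of being fixed in the constructor. A rough old-to-new mapping, with placeholder checkpoint and texts:

# 1.1.3 (removed)
# rag = RAGPipeline(checkpoint="my-embedding-checkpoint", collection_name="rag_memory")
# rag.create(information="some text")
# context = rag.query("some question", k=4)

# 1.2.0
from llmflowstack import VectorDatabase
db = VectorDatabase(checkpoint="my-embedding-checkpoint")
db.get_collection(collection_name="rag_memory")
db.create(collection_name="rag_memory", information="some text")
context = db.query(collection_name="rag_memory", query="some question", k=4)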