llmflowstack 1.1.3__tar.gz → 1.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/PKG-INFO +1 -2
  2. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/__init__.py +8 -8
  3. llmflowstack-1.1.3/llmflowstack/base/base.py → llmflowstack-1.2.0/llmflowstack/decoders/BaseDecoder.py +27 -56
  4. {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/GPT_OSS.py +14 -12
  5. {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/Gemma.py +12 -13
  6. {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/LLaMA3.py +9 -11
  7. {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/LLaMA4.py +15 -15
  8. {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/MedGemma.py +8 -10
  9. llmflowstack-1.2.0/llmflowstack/rag/VectorDatabase.py +278 -0
  10. llmflowstack-1.2.0/llmflowstack/rag/__init__.py +5 -0
  11. llmflowstack-1.2.0/llmflowstack/utils/logging.py +8 -0
  12. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/pyproject.toml +1 -2
  13. llmflowstack-1.1.3/llmflowstack/callbacks/__init__.py +0 -0
  14. llmflowstack-1.1.3/llmflowstack/rag/__iinit__.py +0 -5
  15. llmflowstack-1.1.3/llmflowstack/rag/pipeline.py +0 -279
  16. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/.github/workflows/python-publish.yml +0 -0
  17. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/.gitignore +0 -0
  18. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/LICENSE +0 -0
  19. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/README.md +0 -0
  20. {llmflowstack-1.1.3/llmflowstack/base → llmflowstack-1.2.0/llmflowstack/callbacks}/__init__.py +0 -0
  21. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/callbacks/log_collector.py +0 -0
  22. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/callbacks/stop_on_token.py +0 -0
  23. {llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/__init__.py +0 -0
  24. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/schemas/__init__.py +0 -0
  25. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/schemas/params.py +0 -0
  26. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/utils/__init__.py +0 -0
  27. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/utils/evaluation_methods.py +0 -0
  28. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/utils/exceptions.py +0 -0
  29. {llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/utils/generation_utils.py +0 -0
{llmflowstack-1.1.3 → llmflowstack-1.2.0}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: llmflowstack
- Version: 1.1.3
+ Version: 1.2.0
  Summary: LLMFlowStack is a framework for training and using LLMs (LLaMA, GPT-OSS, Gemma, ...). Supports DAPT, fine-tuning, and distributed inference. Public fork without institution-specific components.
  Author-email: Gustavo Henrique Ferreira Cruz <gustavohferreiracruz@gmail.com>
  License: MIT
@@ -10,7 +10,6 @@ Requires-Dist: accelerate
  Requires-Dist: bert-score
  Requires-Dist: bitsandbytes
  Requires-Dist: chromadb
- Requires-Dist: colorama
  Requires-Dist: datasets
  Requires-Dist: evaluate
  Requires-Dist: huggingface-hub
{llmflowstack-1.1.3 → llmflowstack-1.2.0}/llmflowstack/__init__.py

@@ -1,9 +1,9 @@
- from .models.Gemma import Gemma3
- from .models.GPT_OSS import GPT_OSS
- from .models.LLaMA3 import LLaMA3
- from .models.LLaMA4 import LLaMA4
- from .models.MedGemma import MedGemma
- from .rag.pipeline import RAGPipeline
+ from .decoders.Gemma import Gemma3
+ from .decoders.GPT_OSS import GPT_OSS
+ from .decoders.LLaMA3 import LLaMA3
+ from .decoders.LLaMA4 import LLaMA4
+ from .decoders.MedGemma import MedGemma
+ from .rag import VectorDatabase
  from .schemas.params import (GenerationBeamsParams, GenerationParams,
                               GenerationSampleParams, TrainParams)
  from .utils.evaluation_methods import text_evaluation
@@ -14,10 +14,10 @@ __all__ = [
      "LLaMA3",
      "LLaMA4",
      "MedGemma",
-     "RAGPipeline",
      "GenerationBeamsParams",
      "GenerationParams",
      "GenerationSampleParams",
      "TrainParams",
-     "text_evaluation"
+     "text_evaluation",
+     "VectorDatabase"
  ]
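Editor's note: the practical effect of the `__init__.py` hunk above is a changed public import surface. `RAGPipeline` disappears, `VectorDatabase` is exported instead, and the model classes move from `models/` to `decoders/` without being renamed. A minimal before/after sketch (the checkpoint string is a placeholder, not taken from the diff):

```python
# llmflowstack 1.1.3
# from llmflowstack import RAGPipeline

# llmflowstack 1.2.0 - RAGPipeline was removed; VectorDatabase is exported instead
from llmflowstack import VectorDatabase

db = VectorDatabase(checkpoint="some-sentence-transformer")  # placeholder checkpoint
```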
llmflowstack-1.1.3/llmflowstack/base/base.py → llmflowstack-1.2.0/llmflowstack/decoders/BaseDecoder.py

@@ -1,15 +1,14 @@
  import gc
  import json
- import logging
  import os
  import random
  from abc import ABC, abstractmethod
+ from logging import getLogger
  from typing import Any, Literal, cast
  from uuid import uuid4

  import numpy as np
  import torch
- from colorama import Fore, Style, init
  from datasets import Dataset
  from torch import Tensor
  from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -20,15 +19,15 @@ from trl.trainer.sft_trainer import SFTTrainer
  from llmflowstack.callbacks.log_collector import LogCollectorCallback
  from llmflowstack.schemas.params import GenerationParams, TrainParams
  from llmflowstack.utils.exceptions import MissingEssentialProp
+ from llmflowstack.utils.logging import LogLevel


- class BaseModel(ABC):
+ class BaseDecoder(ABC):
      model = None
      tokenizer = None
      _model_id = None
      model_is_quantized = None
      seed = None
-     log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO"
      stop_token_ids = []
      question_fields = []
      answer_fields = []
@@ -37,20 +36,17 @@ class BaseModel(ABC):
          self,
          checkpoint: str | None = None,
          quantization: Literal["4bit", "8bit"] | bool | None = None,
-         seed: int | None = None,
-         log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+         seed: int | None = None
      ) -> None:
          if not self.question_fields or not self.answer_fields:
              raise NotImplementedError("Subclasses must define question_fields and answer_fields.")

-         init(autoreset=True)
          if seed:
              self._set_seed(seed)

          self._base_model = checkpoint

-         self._set_logger(log_level)
-         self.log_level = log_level
+         self.logger = getLogger(f"LLMFlowStack.{self.__class__.__name__}")

          self.tokenizer: PreTrainedTokenizerBase | None = None

@@ -61,6 +57,17 @@
              quantization=quantization
          )

+     def _log(
+         self,
+         message: str,
+         level: LogLevel = LogLevel.INFO,
+     ) -> None:
+         log_func = getattr(self.logger, level.lower(), None)
+         if log_func:
+             log_func(message)
+         else:
+             self.logger.info(message)
+
      @abstractmethod
      def _load_model(
          self,
@@ -84,7 +91,7 @@
          quantization: Any
      ) -> None:
          if self.model:
-             self._log("A model is already loaded. Attempting to reset it.", "WARNING")
+             self._log("A model is already loaded. Attempting to reset it.", LogLevel.WARNING)
              self.unload_model()

          self._log(f"Loading model on '{checkpoint}'")
@@ -132,42 +139,6 @@
      ) -> None:
          self._model_id = uuid4()

-     def _set_logger(
-         self,
-         level: str
-     ) -> None:
-         level_map = {
-             "DEBUG": logging.DEBUG,
-             "INFO": logging.INFO,
-             "WARNING": logging.WARNING,
-             "ERROR": logging.ERROR,
-         }
-         numeric_level = level_map.get(level.upper(), logging.INFO)
-
-         logging.basicConfig(
-             level=numeric_level,
-             format="%(asctime)s - %(levelname)s - %(message)s"
-         )
-         self.logger = logging.getLogger(__name__)
-
-     def _log(
-         self,
-         info: str,
-         level: Literal["INFO", "WARNING", "ERROR", "DEBUG"] = "INFO"
-     ) -> None:
-         if level == "INFO":
-             colored_msg = f"{Fore.GREEN}{info}{Style.RESET_ALL}"
-             self.logger.info(colored_msg)
-         elif level == "WARNING":
-             colored_msg = f"{Fore.YELLOW}{info}{Style.RESET_ALL}"
-             self.logger.warning(colored_msg)
-         elif level == "ERROR":
-             colored_msg = f"{Fore.RED}{info}{Style.RESET_ALL}"
-             self.logger.error(colored_msg)
-         elif level == "DEBUG":
-             colored_msg = f"{Fore.BLUE}{info}{Style.RESET_ALL}"
-             self.logger.debug(colored_msg)
-
      def _set_seed(
          self,
          seed: int
@@ -190,10 +161,10 @@
          path: str
      ) -> None:
          if not self.model:
-             self._log("No model to save.", "WARNING")
+             self._log("No model to save.", LogLevel.WARNING)
              return None
          if not self.tokenizer:
-             self._log("No tokenizer to save.", "WARNING")
+             self._log("No tokenizer to save.", LogLevel.WARNING)
              return None

          os.makedirs(path, exist_ok=True)
@@ -299,16 +270,16 @@
          save_path: str | None = None
      ) -> None:
          if not self.model:
-             self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
+             self._log("Could not find a model loaded. Try loading a model first.", LogLevel.WARNING)
              return None
          if not self.tokenizer:
-             self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
+             self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", LogLevel.WARNING)
              return None

          self._log("Starting DAPT")

          if self.model_is_quantized:
-             self._log("Cannot DAPT a quantized model.", "WARNING")
+             self._log("Cannot DAPT a quantized model.", LogLevel.WARNING)
              return None

          if params is None:
@@ -443,16 +414,16 @@
          save_path: str | None = None
      ) -> None:
          if not self.model:
-             self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
+             self._log("Could not find a model loaded. Try loading a model first.", LogLevel.WARNING)
              return None
          if not self.tokenizer:
-             self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
+             self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", LogLevel.WARNING)
              return None

          self._log("Starting fine-tune")

          if self.model_is_quantized:
-             self._log("Cannot fine-tune a quantized model.", "WARNING")
+             self._log("Cannot fine-tune a quantized model.", LogLevel.WARNING)
              return None

          if params is None:
@@ -521,8 +492,8 @@
              self._model_id = None
              self._log("Reset successfully.")
          except Exception as e:
-             self._log("Couldn't reset model...", "ERROR")
-             self._log(f"{str(e)}", "DEBUG")
+             self._log("Couldn't reset model...", LogLevel.ERROR)
+             self._log(f"{str(e)}", LogLevel.DEBUG)

      def set_seed(self, seed: int) -> None:
          self._log(f"Setting seed value {seed}")
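Editor's note: with colorama and the per-instance `logging.basicConfig` call removed, `BaseDecoder` now only obtains a named logger (`LLMFlowStack.<ClassName>`) and leaves handler configuration to the host application. A minimal sketch of opting into log output under the new scheme, assuming only the logger-naming pattern visible in the hunk above:

```python
import logging

# The library no longer configures logging itself; the application does.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("LLMFlowStack").setLevel(logging.DEBUG)           # whole-library verbosity
logging.getLogger("LLMFlowStack.LLaMA3").setLevel(logging.WARNING)  # per-class override
```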
{llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/GPT_OSS.py

@@ -1,4 +1,3 @@
- import textwrap
  import threading
  from functools import partial
  from time import time
@@ -11,11 +10,12 @@ from transformers import (AutoTokenizer, StoppingCriteriaList,
  from transformers.models.gpt_oss import GptOssForCausalLM
  from transformers.utils.quantization_config import Mxfp4Config

- from llmflowstack.base.base import BaseModel
  from llmflowstack.callbacks.stop_on_token import StopOnToken
+ from llmflowstack.decoders.BaseDecoder import BaseDecoder
  from llmflowstack.schemas.params import GenerationParams
  from llmflowstack.utils.exceptions import MissingEssentialProp
  from llmflowstack.utils.generation_utils import create_generation_params
+ from llmflowstack.utils.logging import LogLevel


  class GPTOSSInput(TypedDict):
@@ -26,7 +26,7 @@ class GPTOSSInput(TypedDict):
      reasoning_message: str | None
      reasoning_level: Literal["Low", "Medium", "High"] | None

- class GPT_OSS(BaseModel):
+ class GPT_OSS(BaseDecoder):
      model: GptOssForCausalLM | None = None
      reasoning_level: Literal["Low", "Medium", "High"] = "Low"
      question_fields = ["input_text", "developer_message", "system_message"]
@@ -36,14 +36,12 @@ class GPT_OSS(BaseModel):
          self,
          checkpoint: str | None = None,
          quantization: bool | None = None,
-         seed: int | None = None,
-         log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+         seed: int | None = None
      ) -> None:
          return super().__init__(
              checkpoint=checkpoint,
              quantization=quantization,
-             seed=seed,
-             log_level=log_level
+             seed=seed
          )

      def _set_generation_stopping_tokens(
@@ -51,7 +49,7 @@ class GPT_OSS(BaseModel):
          tokens: list[int]
      ) -> None:
          if not self.tokenizer:
-             self._log("Could not set stop tokens - generation may not work...", "WARNING")
+             self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
              return None
          encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
          particular_tokens = encoding.stop_tokens_for_assistant_actions()
@@ -76,7 +74,7 @@ class GPT_OSS(BaseModel):
                  attn_implementation="eager",
              )
          except Exception as _:
-             self._log("Error trying to load the model. Defaulting to load without quantization...", "WARNING")
+             self._log("Error trying to load the model. Defaulting to load without quantization...", LogLevel.WARNING)
              self.model = GptOssForCausalLM.from_pretrained(
                  checkpoint,
                  dtype="auto",
@@ -119,7 +117,7 @@ class GPT_OSS(BaseModel):
          if expected_answer:
              assistant_text += f"<|start|>assistant<|channel|>final<|message|>{expected_answer}<|return|>"

-         return textwrap.dedent(f"""{system_text}{developer_text}<|start|>user<|message|>{data["input_text"]}<|end|>{assistant_text}""")
+         return (
+             f"{system_text}{developer_text}"
+             f"<|start|>user<|message|>{data["input_text"]}<|end|>"
+             f"{assistant_text}"
+         )

      def build_input(
          self,
@@ -154,7 +156,7 @@ class GPT_OSS(BaseModel):
          params: GenerationParams | None = None
      ) -> str | None:
          if self.model is None or self.tokenizer is None:
-             self._log("Model or Tokenizer missing", "WARNING")
+             self._log("Model or Tokenizer missing", LogLevel.WARNING)
              return None

          self._log(f"Processing received input...'")
@@ -222,7 +224,7 @@ class GPT_OSS(BaseModel):
          params: GenerationParams | None = None
      ) -> Iterator[str]:
          if self.model is None or self.tokenizer is None:
-             self._log("Model or Tokenizer missing", "WARNING")
+             self._log("Model or Tokenizer missing", LogLevel.WARNING)
              if False:
                  yield ""
              return
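Editor's note: the same constructor change recurs in every decoder class below, so 1.1.3-style calls that pass `log_level` now raise `TypeError`. A hedged migration sketch (the checkpoint value is illustrative, not from the diff):

```python
from llmflowstack import GPT_OSS

# 1.1.3: model = GPT_OSS(checkpoint="...", seed=42, log_level="DEBUG")  # breaks on 1.2.0
model = GPT_OSS(checkpoint="openai/gpt-oss-20b", seed=42)  # illustrative checkpoint
```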
{llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/Gemma.py

@@ -10,12 +10,13 @@ from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
  from transformers.models.gemma3 import Gemma3ForCausalLM
  from transformers.utils.quantization_config import BitsAndBytesConfig

- from llmflowstack.base.base import BaseModel
  from llmflowstack.callbacks.log_collector import LogCollectorCallback
  from llmflowstack.callbacks.stop_on_token import StopOnToken
+ from llmflowstack.decoders.BaseDecoder import BaseDecoder
  from llmflowstack.schemas.params import GenerationParams, TrainParams
  from llmflowstack.utils.exceptions import MissingEssentialProp
  from llmflowstack.utils.generation_utils import create_generation_params
+ from llmflowstack.utils.logging import LogLevel


  class Gemma3Input(TypedDict):
@@ -24,7 +25,7 @@ class Gemma3Input(TypedDict):
      system_message: str | None
      image_paths: list[str] | None

- class Gemma3(BaseModel):
+ class Gemma3(BaseDecoder):
      model: Gemma3ForCausalLM | None = None
      question_fields = ["input_text", "system_message"]
      answer_fields = ["expected_answer"]
@@ -33,14 +34,12 @@ class Gemma3(BaseModel):
          self,
          checkpoint: str | None = None,
          quantization: Literal["4bit"] | None = None,
-         seed: int | None = None,
-         log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+         seed: int | None = None
      ) -> None:
          return super().__init__(
              checkpoint=checkpoint,
              quantization=quantization,
-             seed=seed,
-             log_level=log_level
+             seed=seed
          )

      def _set_generation_stopping_tokens(
@@ -48,7 +47,7 @@ class Gemma3(BaseModel):
          tokens: list[int]
      ) -> None:
          if not self.tokenizer:
-             self._log("Could not set stop tokens - generation may not work...", "WARNING")
+             self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
              return None
          particular_tokens = self.tokenizer.encode("<end_of_turn>")
          self.stop_token_ids = tokens + particular_tokens
@@ -129,16 +128,16 @@ class Gemma3(BaseModel):
          save_path: str | None = None
      ) -> None:
          if not self.model:
-             self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
+             self._log("Could not find a model loaded. Try loading a model first.", LogLevel.WARNING)
              return None
          if not self.tokenizer:
-             self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
+             self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", LogLevel.WARNING)
              return None

          self._log("Starting Training")

          if self.model_is_quantized:
-             self._log("Cannot traub a quantized model.", "WARNING")
+             self._log("Cannot traub a quantized model.", LogLevel.WARNING)
              return None

          if params is None:
@@ -195,7 +194,7 @@ class Gemma3(BaseModel):
          save_at_end = True,
          save_path: str | None = None
      ) -> None:
-         self._log("Only 'dapt' method is available for this class. Redirecting call to it.", "WARNING")
+         self._log("Only 'dapt' method is available for this class. Redirecting call to it.", LogLevel.WARNING)
          return self.dapt(
              train_dataset=train_dataset,
              params=params,
@@ -210,7 +209,7 @@ class Gemma3(BaseModel):
          params: GenerationParams | None = None,
      ) -> str | None:
          if self.model is None or self.tokenizer is None:
-             self._log("Model or Tokenizer missing", "WARNING")
+             self._log("Model or Tokenizer missing", LogLevel.WARNING)
              return None

          self._log(f"Processing received input...'")
@@ -267,7 +266,7 @@ class Gemma3(BaseModel):
          params: GenerationParams | None = None
      ) -> Iterator[str]:
          if self.model is None or self.tokenizer is None:
-             self._log("Model or Tokenizer missing", "WARNING")
+             self._log("Model or Tokenizer missing", LogLevel.WARNING)
              if False:
                  yield ""
              return
{llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/LLaMA3.py

@@ -1,4 +1,3 @@
- import textwrap
  import threading
  from time import time
  from typing import Iterator, Literal, TypedDict, cast
@@ -9,11 +8,12 @@ from transformers import (AutoTokenizer, StoppingCriteriaList,
  from transformers.models.llama import LlamaForCausalLM
  from transformers.utils.quantization_config import BitsAndBytesConfig

- from llmflowstack.base.base import BaseModel
  from llmflowstack.callbacks.stop_on_token import StopOnToken
+ from llmflowstack.decoders.BaseDecoder import BaseDecoder
  from llmflowstack.schemas.params import GenerationParams
  from llmflowstack.utils.exceptions import MissingEssentialProp
  from llmflowstack.utils.generation_utils import create_generation_params
+ from llmflowstack.utils.logging import LogLevel


  class LLaMA3Input(TypedDict):
@@ -21,7 +21,7 @@ class LLaMA3Input(TypedDict):
      expected_answer: str | None
      system_message: str | None

- class LLaMA3(BaseModel):
+ class LLaMA3(BaseDecoder):
      model: LlamaForCausalLM | None = None
      question_fields = ["input_text", "system_message"]
      answer_fields = ["expected_answer"]
@@ -30,14 +30,12 @@ class LLaMA3(BaseModel):
          self,
          checkpoint: str | None = None,
          quantization: Literal["4bit", "8bit"] | None = None,
-         seed: int | None = None,
-         log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+         seed: int | None = None
      ) -> None:
          return super().__init__(
              checkpoint=checkpoint,
              quantization=quantization,
-             seed=seed,
-             log_level=log_level
+             seed=seed
          )

      def _set_generation_stopping_tokens(
@@ -45,7 +43,7 @@ class LLaMA3(BaseModel):
          tokens: list[int]
      ) -> None:
          if not self.tokenizer:
-             self._log("Could not set stop tokens - generation may not work...", "WARNING")
+             self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
              return None
          particular_tokens = self.tokenizer.encode("<|eot_id|>")
          self.stop_token_ids = tokens + particular_tokens
@@ -92,7 +90,7 @@ class LLaMA3(BaseModel):

          system_message = data.get("system_message", "")

-         return textwrap.dedent(
+         return (
              f"<|start_header_id|>system<|end_header_id|>{system_message}\n"
              f"<|eot_id|><|start_header_id|>user<|end_header_id|>{data["input_text"]}\n"
              f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>{answer}"
@@ -119,7 +117,7 @@ class LLaMA3(BaseModel):
          params: GenerationParams | None = None
      ) -> str | None:
          if self.model is None or self.tokenizer is None:
-             self._log("Model or Tokenizer missing", "WARNING")
+             self._log("Model or Tokenizer missing", LogLevel.WARNING)
              return None

          self.model
@@ -184,7 +182,7 @@ class LLaMA3(BaseModel):
          params: GenerationParams | None = None
      ) -> Iterator[str]:
          if self.model is None or self.tokenizer is None:
-             self._log("Model or Tokenizer missing", "WARNING")
+             self._log("Model or Tokenizer missing", LogLevel.WARNING)
              if False:
                  yield ""
              return
{llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/LLaMA4.py

@@ -1,7 +1,7 @@
  import threading
  from functools import partial
  from time import time
- from typing import Iterator, Literal, TypedDict, cast
+ from typing import Iterator, TypedDict, cast

  import torch
  from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
@@ -9,12 +9,13 @@ from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
                            TrainingArguments)
  from transformers.models.llama4 import Llama4ForCausalLM

- from llmflowstack.base.base import BaseModel
  from llmflowstack.callbacks.log_collector import LogCollectorCallback
  from llmflowstack.callbacks.stop_on_token import StopOnToken
+ from llmflowstack.decoders.BaseDecoder import BaseDecoder
  from llmflowstack.schemas.params import GenerationParams, TrainParams
  from llmflowstack.utils.exceptions import MissingEssentialProp
  from llmflowstack.utils.generation_utils import create_generation_params
+ from llmflowstack.utils.logging import LogLevel


  class LLaMA4Input(TypedDict):
@@ -22,7 +23,7 @@ class LLaMA4Input(TypedDict):
      expected_answer: str | None
      system_message: str | None

- class LLaMA4(BaseModel):
+ class LLaMA4(BaseDecoder):
      model: Llama4ForCausalLM | None = None
      question_fields = ["input_text", "system_message"]
      answer_fields = ["expected_answer"]
@@ -30,14 +31,12 @@ class LLaMA4(BaseModel):
      def __init__(
          self,
          checkpoint: str | None = None,
-         seed: int | None = None,
-         log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+         seed: int | None = None
      ) -> None:
          return super().__init__(
              checkpoint=checkpoint,
              quantization=None,
-             seed=seed,
-             log_level=log_level
+             seed=seed
          )

      def _set_generation_stopping_tokens(
@@ -45,7 +44,7 @@ class LLaMA4(BaseModel):
          tokens: list[int]
      ) -> None:
          if not self.tokenizer:
-             self._log("Could not set stop tokens - generation may not work...", "WARNING")
+             self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
              return None
          particular_tokens = self.tokenizer.encode("<|eot|>")
          self.stop_token_ids = tokens + particular_tokens
@@ -84,7 +83,7 @@ class LLaMA4(BaseModel):
              system_message = f"<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>"

          expected_answer = data.get("expected_answer")
-         answer = f"<|header_start|>assistant<|header_end|>\n\n{expected_answer}<|eot|>" if expected_answer else ""
+         answer = "<|header_start|>assistant<|header_end|>\n\n"
+         answer += f"{expected_answer}<|eot|>" if expected_answer else ""

          return (
              "<|begin_of_text|>"
@@ -118,16 +118,16 @@ class LLaMA4(BaseModel):
          save_path: str | None = None
      ) -> None:
          if not self.model:
-             self._log("Could not find a model loaded. Try loading a model first.", "WARNING")
+             self._log("Could not find a model loaded. Try loading a model first.", LogLevel.WARNING)
              return None
          if not self.tokenizer:
-             self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", "WARNING")
+             self._log("Could not find a tokenizer loaded. Try loading a tokenizer first.", LogLevel.WARNING)
              return None

          self._log("Starting DAPT")

          if self.model_is_quantized:
-             self._log("Cannot DAPT a quantized model.", "WARNING")
+             self._log("Cannot DAPT a quantized model.", LogLevel.WARNING)
              return None

          if params is None:
@@ -184,7 +184,7 @@ class LLaMA4(BaseModel):
          save_at_end = True,
          save_path: str | None = None
      ) -> None:
-         self._log("Only 'dapt' method is available for this class. Redirecting call to it.", "WARNING")
+         self._log("Only 'dapt' method is available for this class. Redirecting call to it.", LogLevel.WARNING)
          return self.dapt(
              train_dataset=train_dataset,
              params=params,
@@ -199,7 +199,7 @@ class LLaMA4(BaseModel):
          params: GenerationParams | None = None
      ) -> str | None:
          if self.model is None or self.tokenizer is None:
-             self._log("Model or Tokenizer missing", "WARNING")
+             self._log("Model or Tokenizer missing", LogLevel.WARNING)
              return None

          self.model
@@ -263,7 +263,7 @@ class LLaMA4(BaseModel):
          params: GenerationParams | None = None
      ) -> Iterator[str]:
          if self.model is None or self.tokenizer is None:
-             self._log("Model or Tokenizer missing", "WARNING")
+             self._log("Model or Tokenizer missing", LogLevel.WARNING)
              if False:
                  yield ""
              return
{llmflowstack-1.1.3/llmflowstack/models → llmflowstack-1.2.0/llmflowstack/decoders}/MedGemma.py

@@ -1,4 +1,3 @@
- import textwrap
  import threading
  from functools import partial
  from time import time
@@ -10,11 +9,12 @@ from transformers import (AutoTokenizer, StoppingCriteriaList,
  from transformers.models.gemma3 import Gemma3ForCausalLM
  from transformers.utils.quantization_config import BitsAndBytesConfig

- from llmflowstack.base.base import BaseModel
  from llmflowstack.callbacks.stop_on_token import StopOnToken
+ from llmflowstack.decoders.BaseDecoder import BaseDecoder
  from llmflowstack.schemas.params import GenerationParams
  from llmflowstack.utils.exceptions import MissingEssentialProp
  from llmflowstack.utils.generation_utils import create_generation_params
+ from llmflowstack.utils.logging import LogLevel


  class MedGemmaInput(TypedDict):
@@ -22,7 +22,7 @@ class MedGemmaInput(TypedDict):
      expected_answer: str | None
      system_message: str | None

- class MedGemma(BaseModel):
+ class MedGemma(BaseDecoder):
      model: Gemma3ForCausalLM | None = None
      can_think = False
      question_fields = ["input_text", "system_message"]
@@ -32,14 +32,12 @@ class MedGemma(BaseModel):
          self,
          checkpoint: str | None = None,
          quantization: Literal["4bit"] | None = None,
-         seed: int | None = None,
-         log_level: Literal["INFO", "DEBUG", "WARNING"] = "INFO",
+         seed: int | None = None
      ) -> None:
          return super().__init__(
              checkpoint=checkpoint,
              quantization=quantization,
-             seed=seed,
-             log_level=log_level
+             seed=seed
          )

      def _set_generation_stopping_tokens(
@@ -47,7 +45,7 @@ class MedGemma(BaseModel):
          tokens: list[int]
      ) -> None:
          if not self.tokenizer:
-             self._log("Could not set stop tokens - generation may not work...", "WARNING")
+             self._log("Could not set stop tokens - generation may not work...", LogLevel.WARNING)
              return None
          particular_tokens = self.tokenizer.encode("<end_of_turn>")
          self.stop_token_ids = tokens + particular_tokens
@@ -128,7 +126,7 @@ class MedGemma(BaseModel):
          params: GenerationParams | None = None,
      ) -> str | None:
          if self.model is None or self.tokenizer is None:
-             self._log("Model or Tokenizer missing", "WARNING")
+             self._log("Model or Tokenizer missing", LogLevel.WARNING)
              return None

          self._log(f"Processing received input...'")
@@ -196,7 +194,7 @@ class MedGemma(BaseModel):
          params: GenerationParams | None = None
      ) -> Iterator[str]:
          if self.model is None or self.tokenizer is None:
-             self._log("Model or Tokenizer missing", "WARNING")
+             self._log("Model or Tokenizer missing", LogLevel.WARNING)
              if False:
                  yield ""
              return
llmflowstack-1.2.0/llmflowstack/rag/VectorDatabase.py (new file)

@@ -0,0 +1,278 @@
+ import gc
+ import uuid
+ from logging import getLogger
+
+ import chromadb
+ import chromadb.config
+ import torch
+ from langchain_chroma import Chroma
+ from langchain_core.documents import Document
+ from langchain_core.embeddings import Embeddings
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from sentence_transformers import SentenceTransformer
+
+ from llmflowstack.utils.exceptions import MissingEssentialProp
+ from llmflowstack.utils.logging import LogLevel
+
+
+ class EncoderWrapper(Embeddings):
+     def __init__(
+         self,
+         model: SentenceTransformer
+     ) -> None:
+         self.model = model
+
+     def embed_documents(
+         self,
+         texts: list[str]
+     ) -> list[list[float]]:
+         vectors = self.model.encode(texts, task="retrieval", show_progress_bar=False)
+         return vectors.tolist()
+
+     def embed_query(
+         self,
+         text: str
+     ) -> list[float]:
+         vectors = self.model.encode(text, task="retrieval", show_progress_bar=False)
+         return vectors.tolist()
+
+ class VectorDatabase:
+     encoder: SentenceTransformer | None = None
+     collections: dict[str, Chroma] = {}
+
+     def __init__(
+         self,
+         checkpoint: str | None = None,
+         chunk_size: int = 1000,
+         chunk_overlap: int = 200
+     ) -> None:
+         self.logger = getLogger(f"LLMFlowStack.{self.__class__.__name__}")
+
+         self.encoder = None
+         if checkpoint:
+             self.load_encoder(
+                 checkpoint=checkpoint
+             )
+
+         self.splitter = RecursiveCharacterTextSplitter(
+             chunk_size=chunk_size,
+             chunk_overlap=chunk_overlap,
+             add_start_index=True,
+         )
+
+     def _log(
+         self,
+         message: str,
+         level: LogLevel = LogLevel.INFO,
+     ) -> None:
+         log_func = getattr(self.logger, level.lower(), None)
+         if log_func:
+             log_func(message)
+         else:
+             self.logger.info(message)
+
+     def load_encoder(
+         self,
+         checkpoint: str
+     ) -> None:
+         if self.encoder:
+             self._log("A encoder is already loaded. Attempting to reset it.", LogLevel.WARNING)
+             self.unload_encoder()
+
+         self._log(f"Loading encoder on '{checkpoint}'")
+         self.encoder = SentenceTransformer(
+             checkpoint,
+             trust_remote_code=True
+         )
+
+         self._log("Encoder loaded")
+
+     def unload_encoder(
+         self
+     ) -> None:
+         try:
+             del self.encoder
+             gc.collect()
+             torch.cuda.empty_cache()
+             self.encoder = None
+             self._log("Reset successfully.")
+         except Exception as e:
+             self._log("Couldn't reset encoder...", LogLevel.ERROR)
+             self._log(f"{str(e)}", LogLevel.DEBUG)
+
+     def get_collection(
+         self,
+         collection_name: str = "rag_memory",
+         persist_directory: str | None = None
+     ) -> None:
+         if not self.encoder:
+             raise MissingEssentialProp("Could not find encoder.")
+
+         client_settings = chromadb.config.Settings(
+             anonymized_telemetry=False
+         )
+
+         self.collections[collection_name] = Chroma(
+             collection_name=collection_name,
+             embedding_function=EncoderWrapper(self.encoder),
+             persist_directory=persist_directory,
+             client_settings=client_settings
+         )
+
+     def validate_collection_name(
+         self,
+         collection_name: str
+     ) -> None:
+         if collection_name not in self.collections:
+             raise ValueError("Collection name not found in collection")
+
+     def index_documents(
+         self,
+         collection_name: str,
+         docs: list[Document],
+         ids: list[str],
+         can_split: bool = True,
+     ) -> None:
+         self.validate_collection_name(
+             collection_name=collection_name
+         )
+
+         for doc, src_id in zip(docs, ids):
+             if doc.metadata is None:
+                 doc.metadata = {}
+             doc.metadata["source_id"] = src_id
+
+         if can_split:
+             splits = self.splitter.split_documents(docs)
+         else:
+             splits = docs
+
+         split_ids = []
+         metadatas = []
+         texts = []
+
+         for i, s in enumerate(splits):
+             src = s.metadata.get("source_id", "unknown")
+             sid = f"{src}_{i}"
+             split_ids.append(sid)
+             metadatas.append(s.metadata.copy())
+             texts.append(s.page_content)
+
+         self.collections[collection_name].add_texts(
+             texts=texts,
+             ids=split_ids,
+             metadatas=metadatas
+         )
+
+     def create(
+         self,
+         collection_name: str,
+         information: str,
+         other_info: dict[str, str] | None = None,
+         doc_id: str | None = None,
+         should_index: bool = True,
+         can_split: bool = True
+     ) -> Document:
+         if other_info is None:
+             other_info = {}
+
+         if doc_id is None:
+             doc_id = str(uuid.uuid4())
+
+         metadata = {"source_id": doc_id, **other_info}
+         doc = Document(
+             page_content=information,
+             metadata=metadata
+         )
+
+         if should_index:
+             self.index_documents(
+                 collection_name=collection_name,
+                 docs=[doc],
+                 ids=[doc_id],
+                 can_split=can_split
+             )
+
+         return doc
+
+     def update(
+         self,
+         collection_name: str,
+         doc_id: str,
+         new_information: str,
+         other_info: dict[str, str] | None = None
+     ) -> Document:
+         self.validate_collection_name(
+             collection_name=collection_name
+         )
+
+         if other_info is None:
+             other_info = {}
+
+         documents_to_delete = self.collections[collection_name].get(
+             where={
+                 "source_id": doc_id
+             }
+         )
+
+         ids_to_delete = documents_to_delete.get("ids", [])
+
+         if ids_to_delete:
+             self.collections[collection_name].delete(ids=ids_to_delete)
+
+         return self.create(
+             collection_name=collection_name,
+             information=new_information,
+             other_info=other_info,
+             doc_id=doc_id
+         )
+
+     def delete(
+         self,
+         collection_name: str,
+         doc_id: str
+     ) -> None:
+         self.validate_collection_name(
+             collection_name=collection_name
+         )
+
+         self.collections[collection_name].delete(ids=[doc_id])
+
+     def rquery(
+         self,
+         collection_name: str,
+         query: str,
+         k: int = 4,
+         filter: dict | None = None
+     ) -> list[Document]:
+         self.validate_collection_name(
+             collection_name=collection_name
+         )
+
+         return self.collections[collection_name].similarity_search(
+             query=query,
+             k=k,
+             filter=filter
+         )
+
+     def query(
+         self,
+         collection_name: str,
+         query: str,
+         k: int = 4,
+         filter: dict | None = None
+     ) -> str:
+         self.validate_collection_name(
+             collection_name=collection_name
+         )
+
+         if filter:
+             docs = self.collections[collection_name].similarity_search(
+                 query=query,
+                 k=k,
+                 filter=filter
+             )
+         else:
+             docs = self.collections[collection_name].similarity_search(query, k=k)
+
+         return "\n\n".join(doc.page_content for doc in docs)
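Editor's note: read end to end, the new module separates encoder loading, collection creation, and CRUD/query operations. A minimal usage sketch built only from the signatures in this file (the encoder checkpoint is a placeholder):

```python
from llmflowstack import VectorDatabase

db = VectorDatabase(checkpoint="my-encoder")  # placeholder checkpoint
db.get_collection(collection_name="notes")    # in-memory unless persist_directory is set

doc = db.create(collection_name="notes", information="Chroma stores embedded chunks.")
print(db.query(collection_name="notes", query="what stores chunks?", k=2))
db.delete(collection_name="notes", doc_id=doc.metadata["source_id"])
```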
llmflowstack-1.2.0/llmflowstack/rag/__init__.py (new file)

@@ -0,0 +1,5 @@
+ from .VectorDatabase import VectorDatabase
+
+ __all__ = [
+     "VectorDatabase"
+ ]
llmflowstack-1.2.0/llmflowstack/utils/logging.py (new file)

@@ -0,0 +1,8 @@
+ from enum import Enum
+
+
+ class LogLevel(str, Enum):
+     INFO = "info"
+     WARNING = "warning"
+     ERROR = "error"
+     DEBUG = "debug"
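Editor's note: because `LogLevel` inherits from `str`, `level.lower()` yields a plain logger method name, which is what makes the `getattr(self.logger, ...)` dispatch in `BaseDecoder._log` and `VectorDatabase._log` work. A small standalone sketch of that dispatch:

```python
import logging
from llmflowstack.utils.logging import LogLevel

logger = logging.getLogger("LLMFlowStack.Demo")
level = LogLevel.WARNING           # str-valued enum member, value "warning"
getattr(logger, level.lower())("something to report")  # resolves to logger.warning(...)
```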
{llmflowstack-1.1.3 → llmflowstack-1.2.0}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "llmflowstack"
- version = "1.1.3"
+ version = "1.2.0"
  authors = [
      { name = "Gustavo Henrique Ferreira Cruz", email = "gustavohferreiracruz@gmail.com" }
  ]
@@ -17,7 +17,6 @@ dependencies = [
      "accelerate",
      "bert-score",
      "bitsandbytes",
-     "colorama",
      "chromadb",
      "datasets",
      "evaluate",
llmflowstack-1.1.3/llmflowstack/rag/__iinit__.py (removed)

@@ -1,5 +0,0 @@
- from .pipeline import RAGPipeline
-
- __all__ = [
-     "RAGPipeline"
- ]
llmflowstack-1.1.3/llmflowstack/rag/pipeline.py (removed)

@@ -1,279 +0,0 @@
- import uuid
-
- import chromadb
- import chromadb.config
- from langchain_chroma import Chroma
- from langchain_core.documents import Document
- from langchain_core.embeddings import Embeddings
- from langchain_text_splitters import RecursiveCharacterTextSplitter
- from sentence_transformers import SentenceTransformer
-
-
- class EncoderWrapper(Embeddings):
-     def __init__(
-         self,
-         model: SentenceTransformer
-     ) -> None:
-         self.model = model
-
-     def embed_documents(
-         self,
-         texts: list[str]
-     ) -> list[list[float]]:
-         vectors = self.model.encode(texts, task="retrieval", show_progress_bar=False)
-         return vectors.tolist()
-
-     def embed_query(
-         self,
-         text: str
-     ) -> list[float]:
-         vectors = self.model.encode(text, task="retrieval", show_progress_bar=False)
-         return vectors.tolist()
-
- class RAGPipeline:
-     """
-     A modular Retrieval-Augmented Generation (RAG) pipeline for embedding, indexing, and retrieving scientific or textual data using SentenceTransformers and Chroma as a vector store.
-
-     Supports both persistent (disk-based) and transient (in-memory) modes depending on whether `persist_directory` is provided.
-     """
-     def __init__(
-         self,
-         checkpoint: str,
-         collection_name: str = "rag_memory",
-         persist_directory: str | None = None,
-         chunk_size: int = 1000,
-         chunk_overlap: int = 200
-     ) -> None:
-         """
-         Initializes the RAG pipeline.
-
-         Args:
-             checkpoint (str): Path or name of the SentenceTransformer checkpoint.
-             collection_name (str): Name of the Chroma collection to create or load.
-             persist_directory (str | None): Directory where the vector database is stored. If None, all data is kept in-memory and discarded after the session ends.
-             chunk_size (int): Maximum size (in characters) for text chunks during indexing.
-             chunk_overlap (int): Overlap (in characters) between consecutive text chunks.
-         """
-         self.encoder = SentenceTransformer(checkpoint, trust_remote_code=True)
-
-         client_settings = chromadb.config.Settings(
-             anonymized_telemetry=False
-         )
-
-         self.collection = Chroma(
-             collection_name=collection_name,
-             embedding_function=EncoderWrapper(self.encoder),
-             persist_directory=persist_directory,
-             client_settings=client_settings
-         )
-
-         self.splitter = RecursiveCharacterTextSplitter(
-             chunk_size=chunk_size,
-             chunk_overlap=chunk_overlap,
-             add_start_index=True,
-         )
-
-     def index_documents(
-         self,
-         docs: list[Document],
-         ids: list[str],
-         can_split: bool = True
-     ) -> None:
-         """
-         Indexes a list of documents into the Chroma vector store.
-
-         Each document is assigned a unique `source_id` and, optionally, split into smaller chunks for more granular retrieval. Each resulting chunk is embedded and stored with its metadata for later similarity search.
-
-         Args:
-             docs (list[Document]): List of LangChain `Document` objects to index.
-             ids (list[str]): Unique identifiers corresponding to each document.
-             can_split (bool): Whether to split documents into smaller chunks before
-                 indexing. Set to False to index each document as a single entry
-                 (e.g., for short or self-contained texts).
-
-         Returns:
-             None
-         """
-         for doc, src_id in zip(docs, ids):
-             if doc.metadata is None:
-                 doc.metadata = {}
-             doc.metadata["source_id"] = src_id
-
-         if can_split:
-             splits = self.splitter.split_documents(docs)
-         else:
-             splits = docs
-
-         split_ids = []
-         metadatas = []
-         texts = []
-
-         for i, s in enumerate(splits):
-             src = s.metadata.get("source_id", "unknown")
-             sid = f"{src}_{i}"
-             split_ids.append(sid)
-             metadatas.append(s.metadata.copy())
-             texts.append(s.page_content)
-
-         self.collection.add_texts(
-             texts=texts,
-             ids=split_ids,
-             metadatas=metadatas
-         )
-
-     def create(
-         self,
-         information: str,
-         other_info: dict[str, str] | None = None,
-         doc_id: str | None = None,
-         should_index: bool = True,
-         can_split: bool = True
-     ) -> Document:
-         """
-         Creates a new `Document` and optionally indexes it in the collection.
-
-         This is a convenience method that wraps both document creation and embedding/indexing in one step. Metadata fields are merged into the document and can include any descriptive information (e.g., title, DOI, year).
-
-         Args:
-             information (str): Main textual content of the document.
-             other_info (dict[str, str] | None): Optional metadata fields to include.
-             doc_id (str | None): Custom document identifier. If None, a UUID is generated.
-             should_index (bool): Whether to immediately add the document to the vector store.
-             can_split (bool): Whether to allow splitting before indexing.
-
-         Returns:
-             Document: The created LangChain `Document` object (indexed if specified).
-         """
-         if other_info is None:
-             other_info = {}
-
-         if doc_id is None:
-             doc_id = str(uuid.uuid4())
-
-         metadata = {"source_id": doc_id, **other_info}
-         doc = Document(
-             page_content=information,
-             metadata=metadata
-         )
-
-         if should_index:
-             self.index_documents(
-                 docs=[doc],
-                 ids=[doc_id],
-                 can_split=can_split
-             )
-
-         return doc
-
-     def update(
-         self,
-         doc_id: str,
-         new_information: str,
-         other_info: dict[str, str] | None = None
-     ) -> Document:
-         """
-         Updates an existing document in the collection with new content and metadata.
-
-         All vector entries associated with the provided `doc_id` are deleted, and a new document is created and re-indexed in their place. This ensures that embeddings remain consistent with the latest text content.
-
-         Args:
-             doc_id (str): Identifier of the document to update.
-             new_information (str): Updated text content for the document.
-             other_info (dict[str, str] | None): Optional new metadata to associate.
-
-         Returns:
-             Document: The newly created (updated) `Document` object.
-         """
-         if other_info is None:
-             other_info = {}
-
-         documents_to_delete = self.collection.get(
-             where={
-                 "source_id": doc_id
-             }
-         )
-
-         ids_to_delete = documents_to_delete.get("ids", [])
-
-         if ids_to_delete:
-             self.collection.delete(ids=ids_to_delete)
-
-         return self.create(
-             information=new_information,
-             other_info=other_info,
-             doc_id=doc_id
-         )
-
-     def delete(
-         self,
-         doc_id: str
-     ) -> None:
-         """
-         Deletes all indexed entries associated with a specific document ID.
-
-         Removes all vectors and metadata tied to the provided `doc_id` from the collection. Use this to completely erase a document's content from the indexed database.
-
-         Args:
-             doc_id (str): Identifier of the document to delete.
-
-         Returns:
-             None
-         """
-         self.collection.delete(ids=[doc_id])
-
-     def rquery(
-         self,
-         query: str,
-         k: int = 4,
-         filter: dict | None = None
-     ) -> list[Document]:
-         """
-         Perform a **raw semantic search** on the collection.
-
-         This method queries the vector store using the provided text query and returns the top-`k` most similar `Document` objects, optionally filtered by metadata.
-
-         Args:
-             query (str): The natural-language query text to embed and search for.
-             k (int, optional): Number of top results to return. Defaults to 4.
-             filter (dict | None, optional): Metadata filter applied during search
-                 (e.g., {"type": "article"}). Defaults to None.
-
-         Returns:
-             list[Document]: A list of matching documents sorted by similarity score.
-         """
-         return self.collection.similarity_search(
-             query=query,
-             k=k,
-             filter=filter
-         )
-
-     def query(
-         self,
-         query: str,
-         k: int = 4,
-         filter: dict | None = None
-     ) -> str:
-         """
-         Perform a **semantic search** and return the combined text content.
-
-         This method wraps `rquery()` and concatenates the retrieved document contents into a single string, suitable for direct use in downstream LLM prompts or text processing.
-
-         Args:
-             query (str): The natural-language query text to search for.
-             k (int, optional): Number of top results to return. Defaults to 4.
-             filter (dict | None, optional): Metadata filter applied during search. If None, all documents are considered.
-
-         Returns:
-             str: A newline-separated string containing the page contents of
-                 the retrieved documents.
-         """
-         if filter:
-             docs = self.collection.similarity_search(
-                 query=query,
-                 k=k,
-                 filter=filter
-             )
-         else:
-             docs = self.collection.similarity_search(query, k=k)
-
-         return "\n\n".join(doc.page_content for doc in docs)
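Editor's note: taken together with `VectorDatabase.py` above, removing `RAGPipeline` splits what was a single constructor into explicit steps, and the query methods gain a leading `collection_name` argument. A hedged old-to-new mapping (argument values are placeholders):

```python
# 1.1.3
# rag = RAGPipeline(checkpoint="enc", collection_name="rag_memory", persist_directory="./db")
# context = rag.query("question", k=4)

# 1.2.0
from llmflowstack import VectorDatabase

db = VectorDatabase(checkpoint="enc")  # placeholder checkpoint; loads the encoder
db.get_collection(collection_name="rag_memory", persist_directory="./db")
context = db.query(collection_name="rag_memory", query="question", k=4)
```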