arize 8.0.0a9__py3-none-any.whl → 8.0.0a11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arize/client.py CHANGED
@@ -12,11 +12,13 @@ if TYPE_CHECKING:
      from arize.spans.client import SpansClient


+ # TODO(Kiko): experimental/datasets must be adapted into the datasets subclient
+ # TODO(Kiko): experimental/prompt hub is missing
+ # TODO(Kiko): exporter/utils/schema_parser is missing
  # TODO(Kiko): Go through main APIs and add CtxAdapter where missing
  # TODO(Kiko): Search and handle other TODOs
  # TODO(Kiko): Go over **every file** and do not import anything at runtime, use `if TYPE_CHECKING`
  # with `from __future__ import annotations` (must include for Python < 3.11)
- # TODO(Kiko): MIMIC Explainer not done
  # TODO(Kiko): Go over docstrings
  class ArizeClient(LazySubclientsMixin):
      """
@@ -0,0 +1,4 @@
+ from arize.embeddings.auto_generator import EmbeddingGenerator
+ from arize.embeddings.usecases import UseCases
+
+ __all__ = ["EmbeddingGenerator", "UseCases"]
@@ -0,0 +1,108 @@
+ from typing import Any
+
+ import pandas as pd
+
+ from arize.embeddings import constants
+ from arize.embeddings.base_generators import BaseEmbeddingGenerator
+ from arize.embeddings.constants import (
+     CV_PRETRAINED_MODELS,
+     DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
+     DEFAULT_CV_OBJECT_DETECTION_MODEL,
+     DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
+     DEFAULT_NLP_SUMMARIZATION_MODEL,
+     DEFAULT_TABULAR_MODEL,
+     NLP_PRETRAINED_MODELS,
+ )
+ from arize.embeddings.cv_generators import (
+     EmbeddingGeneratorForCVImageClassification,
+     EmbeddingGeneratorForCVObjectDetection,
+ )
+ from arize.embeddings.nlp_generators import (
+     EmbeddingGeneratorForNLPSequenceClassification,
+     EmbeddingGeneratorForNLPSummarization,
+ )
+ from arize.embeddings.tabular_generators import (
+     EmbeddingGeneratorForTabularFeatures,
+ )
+ from arize.embeddings.usecases import UseCases
+
+ UseCaseLike = str | UseCases.NLP | UseCases.CV | UseCases.STRUCTURED
+
+
+ class EmbeddingGenerator:
+     def __init__(self, **kwargs: str):
+         raise OSError(
+             f"{self.__class__.__name__} is designed to be instantiated using the "
+             f"`{self.__class__.__name__}.from_use_case(use_case, **kwargs)` method."
+         )
+
+     @staticmethod
+     def from_use_case(
+         use_case: UseCaseLike, **kwargs: Any
+     ) -> BaseEmbeddingGenerator:
+         if use_case == UseCases.NLP.SEQUENCE_CLASSIFICATION:
+             return EmbeddingGeneratorForNLPSequenceClassification(**kwargs)
+         elif use_case == UseCases.NLP.SUMMARIZATION:
+             return EmbeddingGeneratorForNLPSummarization(**kwargs)
+         elif use_case == UseCases.CV.IMAGE_CLASSIFICATION:
+             return EmbeddingGeneratorForCVImageClassification(**kwargs)
+         elif use_case == UseCases.CV.OBJECT_DETECTION:
+             return EmbeddingGeneratorForCVObjectDetection(**kwargs)
+         elif use_case == UseCases.STRUCTURED.TABULAR_EMBEDDINGS:
+             return EmbeddingGeneratorForTabularFeatures(**kwargs)
+         else:
+             raise ValueError(f"Invalid use case {use_case}")
+
+     @classmethod
+     def list_default_models(cls) -> pd.DataFrame:
+         df = pd.DataFrame(
+             {
+                 "Area": ["NLP", "NLP", "CV", "CV", "STRUCTURED"],
+                 "Usecase": [
+                     UseCases.NLP.SEQUENCE_CLASSIFICATION.name,
+                     UseCases.NLP.SUMMARIZATION.name,
+                     UseCases.CV.IMAGE_CLASSIFICATION.name,
+                     UseCases.CV.OBJECT_DETECTION.name,
+                     UseCases.STRUCTURED.TABULAR_EMBEDDINGS.name,
+                 ],
+                 "Model Name": [
+                     DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
+                     DEFAULT_NLP_SUMMARIZATION_MODEL,
+                     DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
+                     DEFAULT_CV_OBJECT_DETECTION_MODEL,
+                     DEFAULT_TABULAR_MODEL,
+                 ],
+             }
+         )
+         df.sort_values(
+             by=[col for col in df.columns], ascending=True, inplace=True
+         )
+         return df.reset_index(drop=True)
+
+     @classmethod
+     def list_pretrained_models(cls) -> pd.DataFrame:
+         data = {
+             "Task": ["NLP" for _ in NLP_PRETRAINED_MODELS]
+             + ["CV" for _ in CV_PRETRAINED_MODELS],
+             "Architecture": [
+                 cls.__parse_model_arch(model)
+                 for model in NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS
+             ],
+             "Model Name": NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS,
+         }
+         df = pd.DataFrame(data)
+         df.sort_values(
+             by=[col for col in df.columns], ascending=True, inplace=True
+         )
+         return df.reset_index(drop=True)
+
+     @staticmethod
+     def __parse_model_arch(model_name: str) -> str:
+         if constants.GPT.lower() in model_name.lower():
+             return constants.GPT
+         elif constants.BERT.lower() in model_name.lower():
+             return constants.BERT
+         elif constants.VIT.lower() in model_name.lower():
+             return constants.VIT
+         else:
+             raise ValueError("Invalid model_name, unknown architecture.")
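The hunk above adds the EmbeddingGenerator factory exported by the new arize.embeddings package. As an illustrative usage sketch (it assumes the optional extras are installed via pip install 'arize[auto-embeddings]', as stated in IMPORT_ERROR_MESSAGE later in this diff; the surrounding script is hypothetical):

from arize.embeddings import EmbeddingGenerator, UseCases

# Inspect the default and supported checkpoints bundled with the package.
print(EmbeddingGenerator.list_default_models())
print(EmbeddingGenerator.list_pretrained_models())

# Build a concrete generator for a use case; model_name and batch_size are
# optional and fall back to the defaults defined in this diff.
generator = EmbeddingGenerator.from_use_case(
    UseCases.NLP.SEQUENCE_CLASSIFICATION,
    model_name="distilbert-base-uncased",
    batch_size=100,
)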
@@ -0,0 +1,255 @@
+ import os
+ from abc import ABC, abstractmethod
+ from enum import Enum
+ from functools import partial
+ from typing import Dict, List, Union, cast
+
+ import pandas as pd
+
+ import arize.embeddings.errors as err
+ from arize.embeddings.constants import IMPORT_ERROR_MESSAGE
+
+ try:
+     import torch
+     from datasets import Dataset
+     from PIL import Image
+     from transformers import ( # type: ignore
+         AutoImageProcessor,
+         AutoModel,
+         AutoTokenizer,
+         BatchEncoding,
+     )
+     from transformers.utils import logging as transformer_logging
+ except ImportError as e:
+     raise ImportError(IMPORT_ERROR_MESSAGE) from e
+
+ import logging
+
+ logger = logging.getLogger(__name__)
+ transformer_logging.set_verbosity(50)
+ transformer_logging.enable_progress_bar()
+
+
+ class BaseEmbeddingGenerator(ABC):
+     def __init__(
+         self, use_case: Enum, model_name: str, batch_size: int = 100, **kwargs
+     ):
+         self.__use_case = self._parse_use_case(use_case=use_case)
+         self.__model_name = model_name
+         self.__device = self.select_device()
+         self.__batch_size = batch_size
+         logger.info(f"Downloading pre-trained model '{self.model_name}'")
+         try:
+             self.__model = AutoModel.from_pretrained(
+                 self.model_name, **kwargs
+             ).to(self.device)
+         except OSError as e:
+             raise err.HuggingFaceRepositoryNotFound(model_name) from e
+         except Exception as e:
+             raise e
+
+     @abstractmethod
+     def generate_embeddings(self, **kwargs) -> pd.Series: ...
+
+     def select_device(self) -> torch.device:
+         if torch.cuda.is_available():
+             return torch.device("cuda")
+         elif torch.backends.mps.is_available():
+             return torch.device("mps")
+         else:
+             logger.warning(
+                 "No available GPU has been detected. The use of GPU acceleration is "
+                 "strongly recommended. You can check for GPU availability by running "
+                 "`torch.cuda.is_available()` or `torch.backends.mps.is_available()`."
+             )
+             return torch.device("cpu")
+
+     @property
+     def use_case(self) -> str:
+         return self.__use_case
+
+     @property
+     def model_name(self) -> str:
+         return self.__model_name
+
+     @property
+     def model(self):
+         return self.__model
+
+     @property
+     def device(self) -> torch.device:
+         return self.__device
+
+     @property
+     def batch_size(self) -> int:
+         return self.__batch_size
+
+     @batch_size.setter
+     def batch_size(self, new_batch_size: int) -> None:
+         err_message = "New batch size should be an integer greater than 0."
+         if not isinstance(new_batch_size, int):
+             raise TypeError(err_message)
+         elif new_batch_size <= 0:
+             raise ValueError(err_message)
+         else:
+             self.__batch_size = new_batch_size
+             logger.info(f"Batch size has been set to {new_batch_size}.")
+
+     @staticmethod
+     def _parse_use_case(use_case: Enum) -> str:
+         uc_area = use_case.__class__.__name__.split("UseCases")[0]
+         uc_task = use_case.name
+         return f"{uc_area}.{uc_task}"
+
+     def _get_embedding_vector(
+         self, batch: Dict[str, torch.Tensor], method
+     ) -> Dict[str, torch.Tensor]:
+         with torch.no_grad():
+             outputs = self.model(**batch)
+         # (batch_size, seq_length/or/num_tokens, hidden_size)
+         if method == "cls_token": # Select CLS token vector
+             embeddings = outputs.last_hidden_state[:, 0, :]
+         elif method == "avg_token": # Select avg token vector
+             embeddings = torch.mean(outputs.last_hidden_state, 1)
+         else:
+             raise ValueError(f"Invalid method = {method}")
+         return {"embedding_vector": embeddings.cpu().numpy().astype(float)}
+
+     @staticmethod
+     def check_invalid_index(field: Union[pd.Series, pd.DataFrame]) -> None:
+         if (field.index != field.reset_index(drop=True).index).any():
+             if isinstance(field, pd.DataFrame):
+                 raise err.InvalidIndexError("DataFrame")
+             else:
+                 raise err.InvalidIndexError(str(field.name))
+
+     @abstractmethod
+     def __repr__(self) -> str:
+         pass
+
+
+ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
+     def __repr__(self) -> str:
+         return (
+             f"{self.__class__.__name__}(\n"
+             f" use_case={self.use_case},\n"
+             f" model_name='{self.model_name}',\n"
+             f" tokenizer_max_length={self.tokenizer_max_length},\n"
+             f" tokenizer={self.tokenizer.__class__},\n"
+             f" model={self.model.__class__},\n"
+             f" batch_size={self.batch_size},\n"
+             f")"
+         )
+
+     def __init__(
+         self,
+         use_case: Enum,
+         model_name: str,
+         tokenizer_max_length: int = 512,
+         **kwargs,
+     ):
+         super().__init__(use_case=use_case, model_name=model_name, **kwargs)
+         self.__tokenizer_max_length = tokenizer_max_length
+         # We don't check for the tokenizer's existence since it is coupled with the corresponding model
+         # We check the model's existence in `BaseEmbeddingGenerator`
+         logger.info(f"Downloading tokenizer for '{self.model_name}'")
+         self.__tokenizer = AutoTokenizer.from_pretrained(
+             self.model_name, model_max_length=self.tokenizer_max_length
+         )
+
+     @property
+     def tokenizer(self):
+         return self.__tokenizer
+
+     @property
+     def tokenizer_max_length(self) -> int:
+         return self.__tokenizer_max_length
+
+     def tokenize(
+         self, batch: Dict[str, List[str]], text_feat_name: str
+     ) -> BatchEncoding:
+         return self.tokenizer(
+             batch[text_feat_name],
+             padding=True,
+             truncation=True,
+             max_length=self.tokenizer_max_length,
+             return_tensors="pt",
+         ).to(self.device)
+
+
+ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
+     def __repr__(self) -> str:
+         return (
+             f"{self.__class__.__name__}(\n"
+             f" use_case={self.use_case},\n"
+             f" model_name='{self.model_name}',\n"
+             f" image_processor={self.image_processor.__class__},\n"
+             f" model={self.model.__class__},\n"
+             f" batch_size={self.batch_size},\n"
+             f")"
+         )
+
+     def __init__(self, use_case: Enum, model_name: str, **kwargs):
+         super().__init__(use_case=use_case, model_name=model_name, **kwargs)
+         logger.info("Downloading image processor")
+         # We don't check for the image processor's existence since it is coupled with the corresponding model
+         # We check the model's existence in `BaseEmbeddingGenerator`
+         self.__image_processor = AutoImageProcessor.from_pretrained(
+             self.model_name
+         )
+
+     @property
+     def image_processor(self):
+         return self.__image_processor
+
+     @staticmethod
+     def open_image(image_path: str) -> Image.Image:
+         if not os.path.exists(image_path):
+             raise ValueError(f"Cannot find image {image_path}")
+         return Image.open(image_path).convert("RGB")
+
+     def preprocess_image(
+         self, batch: Dict[str, List[str]], local_image_feat_name: str
+     ):
+         return self.image_processor(
+             [
+                 self.open_image(image_path)
+                 for image_path in batch[local_image_feat_name]
+             ],
+             return_tensors="pt",
+         ).to(self.device)
+
+     def generate_embeddings(self, local_image_path_col: pd.Series) -> pd.Series:
+         """
+         Obtain embedding vectors from your image data using pre-trained image models.
+
+         :param local_image_path_col: a pandas Series containing the local path to the images to
+             be used to generate the embedding vectors.
+         :return: a pandas Series containing the embedding vectors.
+         """
+         if not isinstance(local_image_path_col, pd.Series):
+             raise TypeError(
+                 "local_image_path_col must be a pandas Series"
+             )
+         self.check_invalid_index(field=local_image_path_col)
+
+         # Validate that there are no null image paths
+         if local_image_path_col.isnull().any():
+             raise ValueError(
+                 "There can't be any null values in the local_image_path_col series"
+             )
+
+         ds = Dataset.from_dict({"local_path": local_image_path_col})
+         ds.set_transform(
+             partial(
+                 self.preprocess_image,
+                 local_image_feat_name="local_path",
+             )
+         )
+         logger.info("Generating embedding vectors")
+         ds = ds.map(
+             lambda batch: self._get_embedding_vector(batch, "avg_token"),
+             batched=True,
+             batch_size=self.batch_size,
+         )
+         return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
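For orientation, a minimal sketch of the CV path defined by CVEmbeddingGenerator.generate_embeddings above; the DataFrame, its image_path column, and the file paths are hypothetical placeholders:

import pandas as pd

from arize.embeddings import EmbeddingGenerator, UseCases

df = pd.DataFrame({"image_path": ["/tmp/images/0001.png", "/tmp/images/0002.png"]})

generator = EmbeddingGenerator.from_use_case(UseCases.CV.IMAGE_CLASSIFICATION)
# Expects a pandas Series of local image paths with a default RangeIndex;
# a non-default index raises InvalidIndexError and null paths raise ValueError.
df["image_vector"] = generator.generate_embeddings(local_image_path_col=df["image_path"])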
@@ -0,0 +1,34 @@
+ DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL = "distilbert-base-uncased"
+ DEFAULT_NLP_SUMMARIZATION_MODEL = "distilbert-base-uncased"
+ DEFAULT_TABULAR_MODEL = "distilbert-base-uncased"
+ DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL = "google/vit-base-patch32-224-in21k"
+ DEFAULT_CV_OBJECT_DETECTION_MODEL = "facebook/detr-resnet-101"
+ NLP_PRETRAINED_MODELS = [
+     "bert-base-cased",
+     "bert-base-uncased",
+     "bert-large-cased",
+     "bert-large-uncased",
+     "distilbert-base-cased",
+     "distilbert-base-uncased",
+     "xlm-roberta-base",
+     "xlm-roberta-large",
+ ]
+
+ CV_PRETRAINED_MODELS = [
+     "google/vit-base-patch16-224-in21k",
+     "google/vit-base-patch16-384",
+     "google/vit-base-patch32-224-in21k",
+     "google/vit-base-patch32-384",
+     "google/vit-large-patch16-224-in21k",
+     "google/vit-large-patch16-384",
+     "google/vit-large-patch32-224-in21k",
+     "google/vit-large-patch32-384",
+ ]
+ IMPORT_ERROR_MESSAGE = (
+     "To enable embedding generation, the arize module must be installed with "
+     "extra dependencies. Run: pip install 'arize[auto-embeddings]'."
+ )
+
+ GPT = "GPT"
+ BERT = "BERT"
+ VIT = "ViT"
@@ -0,0 +1,28 @@
+ from arize.embeddings.base_generators import CVEmbeddingGenerator
+ from arize.embeddings.constants import (
+     DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
+     DEFAULT_CV_OBJECT_DETECTION_MODEL,
+ )
+ from arize.embeddings.usecases import UseCases
+
+
+ class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
+     def __init__(
+         self, model_name: str = DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL, **kwargs
+     ):
+         super().__init__(
+             use_case=UseCases.CV.IMAGE_CLASSIFICATION,
+             model_name=model_name,
+             **kwargs,
+         )
+
+
+ class EmbeddingGeneratorForCVObjectDetection(CVEmbeddingGenerator):
+     def __init__(
+         self, model_name: str = DEFAULT_CV_OBJECT_DETECTION_MODEL, **kwargs
+     ):
+         super().__init__(
+             use_case=UseCases.CV.OBJECT_DETECTION,
+             model_name=model_name,
+             **kwargs,
+         )
@@ -0,0 +1,41 @@
+ class InvalidIndexError(Exception):
+     def __repr__(self) -> str:
+         return "Invalid_Index_Error"
+
+     def __str__(self) -> str:
+         return self.error_message()
+
+     def __init__(self, field_name: str) -> None:
+         self.field_name = field_name
+
+     def error_message(self) -> str:
+         if self.field_name == "DataFrame":
+             return (
+                 f"The index of the {self.field_name} is invalid; "
+                 f"reset the index by using df.reset_index(drop=True, inplace=True)"
+             )
+         else:
+             return (
+                 f"The index of the Series given by the column '{self.field_name}' is invalid; "
+                 f"reset the index by using df.reset_index(drop=True, inplace=True)"
+             )
+
+
+ class HuggingFaceRepositoryNotFound(Exception):
+     def __repr__(self) -> str:
+         return "HuggingFace_Repository_Not_Found_Error"
+
+     def __str__(self) -> str:
+         return self.error_message()
+
+     def __init__(self, model_name: str) -> None:
+         self.model_name = model_name
+
+     def error_message(self) -> str:
+         return (
+             f"The given model name '{self.model_name}' is not a valid model identifier listed on "
+             "'https://huggingface.co/models'. "
+             "If this is a private repository, log in with `huggingface-cli login` or import "
+             "`login` from `huggingface_hub` if you are using a notebook. "
+             "Learn more at https://huggingface.co/docs/huggingface_hub/quick-start#login"
+         )
@@ -0,0 +1,111 @@
+ import logging
+ from functools import partial
+ from typing import Optional, cast
+
+ import pandas as pd
+
+ from arize.embeddings.base_generators import NLPEmbeddingGenerator
+ from arize.embeddings.constants import (
+     DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
+     DEFAULT_NLP_SUMMARIZATION_MODEL,
+     IMPORT_ERROR_MESSAGE,
+ )
+ from arize.embeddings.usecases import UseCases
+
+ try:
+     from datasets import Dataset
+ except ImportError:
+     raise ImportError(IMPORT_ERROR_MESSAGE) from None
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
+     def __init__(
+         self,
+         model_name: str = DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
+         **kwargs,
+     ):
+         super().__init__(
+             use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
+             model_name=model_name,
+             **kwargs,
+         )
+
+     def generate_embeddings(
+         self,
+         text_col: pd.Series,
+         class_label_col: Optional[pd.Series] = None,
+     ) -> pd.Series:
+         """
+         Obtain embedding vectors from your text data using pre-trained large language models.
+
+         :param text_col: a pandas Series containing the different pieces of text.
+         :param class_label_col: if this column is passed, the sentence "The classification label
+             is <class_label>" will be prepended to the text in the `text_col`.
+         :return: a pandas Series containing the embedding vectors.
+         """
+         if not isinstance(text_col, pd.Series):
+             raise TypeError("text_col must be a pandas Series")
+
+         self.check_invalid_index(field=text_col)
+
+         if class_label_col is not None:
+             if not isinstance(class_label_col, pd.Series):
+                 raise TypeError("class_label_col must be a pandas Series")
+             df = pd.concat(
+                 {"text": text_col, "class_label": class_label_col}, axis=1
+             )
+             prepared_text_col = df.apply(
+                 lambda row: f" The classification label is {row['class_label']}. {row['text']}",
+                 axis=1,
+             )
+             ds = Dataset.from_dict({"text": prepared_text_col})
+         else:
+             ds = Dataset.from_dict({"text": text_col})
+
+         ds.set_transform(partial(self.tokenize, text_feat_name="text"))
+         logger.info("Generating embedding vectors")
+         ds = ds.map(
+             lambda batch: self._get_embedding_vector(batch, "cls_token"),
+             batched=True,
+             batch_size=self.batch_size,
+         )
+         return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
+
+
+ class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
+     def __init__(
+         self, model_name: str = DEFAULT_NLP_SUMMARIZATION_MODEL, **kwargs
+     ):
+         super().__init__(
+             use_case=UseCases.NLP.SUMMARIZATION,
+             model_name=model_name,
+             **kwargs,
+         )
+
+     def generate_embeddings(
+         self,
+         text_col: pd.Series,
+     ) -> pd.Series:
+         """
+         Obtain embedding vectors from your text data using pre-trained large language models.
+
+         :param text_col: a pandas Series containing the different pieces of text.
+         :return: a pandas Series containing the embedding vectors.
+         """
+         if not isinstance(text_col, pd.Series):
+             raise TypeError("text_col must be a pandas Series")
+         self.check_invalid_index(field=text_col)
+
+         ds = Dataset.from_dict({"text": text_col})
+
+         ds.set_transform(partial(self.tokenize, text_feat_name="text"))
+         logger.info("Generating embedding vectors")
+         ds = ds.map(
+             lambda batch: self._get_embedding_vector(batch, "cls_token"),
+             batched=True,
+             batch_size=self.batch_size,
+         )
+         return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
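Likewise, a minimal sketch of the NLP sequence-classification path above; the DataFrame and its column names are hypothetical placeholders:

import pandas as pd

from arize.embeddings import EmbeddingGenerator, UseCases

df = pd.DataFrame(
    {
        "text": ["great battery life", "stopped working after a week"],
        "label": ["positive", "negative"],
    }
)

generator = EmbeddingGenerator.from_use_case(UseCases.NLP.SEQUENCE_CLASSIFICATION)
# When class_label_col is provided, "The classification label is <label>." is
# prepended to each row's text before tokenization (see generate_embeddings above).
df["text_vector"] = generator.generate_embeddings(
    text_col=df["text"],
    class_label_col=df["label"],
)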