ddi-fw 0.0.149__py3-none-any.whl → 0.0.151__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,7 +5,7 @@ from .mdf_sa_ddi.base import MDFSADDIDataset
  from .embedding_generator import create_embeddings
  from .idf_helper import IDF
  from .feature_vector_generation import SimilarityMatrixGenerator, VectorGenerator
-
+ from .dataset_splitter import DatasetSplitter
  __all__ = ['BaseDataset','DDIMDLDataset','MDFSADDIDataset']


ddi_fw/datasets/core.py CHANGED
@@ -1,234 +1,129 @@
  import glob
- from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
- from sklearn.preprocessing import LabelBinarizer
- from abc import ABC, abstractmethod
+ from typing import List, Optional, Type
  import numpy as np
  import pandas as pd
- import pathlib
+ from pydantic import BaseModel, Field, computed_field
+ from ddi_fw.datasets.dataset_splitter import DatasetSplitter
+ from ddi_fw.datasets.feature_vector_generation import SimilarityMatrixGenerator, VectorGenerator
  from ddi_fw.langchain.embeddings import PoolingStrategy
- from ddi_fw.datasets.idf_helper import IDF
-
- from ddi_fw.utils.zip_helper import ZipHelper
- from .feature_vector_generation import SimilarityMatrixGenerator, VectorGenerator
- # from ddi_fw.ner.ner import CTakesNER
- from ddi_fw.utils import create_folder_if_not_exists
- from stopwatch import Stopwatch, profile
-
- HERE = pathlib.Path(__file__).resolve().parent
+ from ddi_fw.utils.utils import create_folder_if_not_exists


  def stack(df_column):
  return np.stack(df_column.values)


- class BaseDataset(ABC):
- def __init__(self,
- embedding_size,
- embedding_dict,
- embeddings_pooling_strategy: PoolingStrategy,
- ner_df,
- chemical_property_columns,
- embedding_columns,
- ner_columns,
- **kwargs):
- self.embedding_size = embedding_size
- self.embedding_dict = embedding_dict
- self.embeddings_pooling_strategy = embeddings_pooling_strategy
- self.ner_df = ner_df
- self.__similarity_related_columns__ = []
- self.__similarity_related_columns__.extend(chemical_property_columns)
- self.__similarity_related_columns__.extend(ner_columns)
-
- self.chemical_property_columns = chemical_property_columns
- self.embedding_columns = embedding_columns
- self.ner_columns = ner_columns
- self.threshold_method = kwargs.get('threshold_method', 'idf')
- self.tui_threshold = kwargs.get('tui_threshold', 0)
- self.cui_threshold = kwargs.get('cui_threshold', 0)
- self.entities_threshold = kwargs.get('entities_threshold', 0)
-
- self.stopwatch = Stopwatch()
-
- # self.store_similarity_matrices = kwargs.get('store_similarity_matrices', True)
- # self.similarity_matrices_path = kwargs.get('similarity_matrices_path', True)
-
- # load or split must be run first
+ def generate_vectors(df, columns):
+ vectorGenerator = VectorGenerator(df)
+ generated_vectors = vectorGenerator.generate_feature_vectors(
+ columns)
+ return generated_vectors
+
+
+ def generate_sim_matrices_new(df, generated_vectors, columns, key_column="id"):
+ jaccard_sim_dict = {}
+ sim_matrix_gen = SimilarityMatrixGenerator()
+
+ for column in columns:
+ # key = '2D_'+column
+ key = column
+ jaccard_sim_dict[column] = sim_matrix_gen.create_jaccard_similarity_matrices(
+ generated_vectors[key])
+
+ similarity_matrices = {}
+ keys = df[key_column].to_list()
+ new_columns = {}
+ for idx in range(len(keys)):
+ new_columns[idx] = keys[idx]
+ for column in columns:
+ new_df = pd.DataFrame.from_dict(jaccard_sim_dict[column])
+ new_df = new_df.rename(index=new_columns, columns=new_columns)
+ similarity_matrices[column] = new_df
+ return similarity_matrices
+
+
+ class BaseDataset(BaseModel):
+ dataset_name: str
+ index_path: str
+ dataset_splitter_type: Type[DatasetSplitter]
+ class_column: str = 'class'
+ dataframe: Optional[pd.DataFrame] = None
+ X_train: Optional[pd.DataFrame] = None
+ X_test: Optional[pd.DataFrame] = None
+ y_train: Optional[pd.Series] = None
+ y_test: Optional[pd.Series] = None
+ train_indexes: Optional[pd.Index] = None
+ test_indexes: Optional[pd.Index] = None
+ train_idx_arr: List|None = None
+ val_idx_arr: List|None = None
+ # train_idx_arr: Optional[List[np.ndarray]] = None
+ # val_idx_arr: Optional[List[np.ndarray]] = None
+ columns: List[str] = []
+
+ # feature_process: FeatureProcessor
+ # similarity_matrix_service: SimilarityMatrixService
+
+ class Config:
+ arbitrary_types_allowed = True
+
  def produce_inputs(self):
  items = []
+ if self.X_train is None or self.X_test is None:
+ raise Exception("There is no data to produce inputs")
  y_train_label, y_test_label = stack(self.y_train), stack(self.y_test)
- # self.__similarity_related_columns__.append("smile_2") #TODO
- for column in self.__similarity_related_columns__:
+
+ for column in self.columns:
  train_data, test_data = stack(
  self.X_train[column]), stack(self.X_test[column])
  items.append([f'{column}', np.nan_to_num(train_data),
  y_train_label, np.nan_to_num(test_data), y_test_label])
- for column in self.embedding_columns:
- train_data, test_data = stack(
- self.X_train[column+'_embedding']), stack(self.X_test[column+'_embedding'])
- items.append([f'{column}_embedding', train_data,
- y_train_label, test_data, y_test_label])
+
+ # items.append([f'{column}_embedding', train_data,
+ # y_train_label, test_data, y_test_label])
  return items
+
+ @computed_field
+ @property
+ def dataset_splitter(self) -> DatasetSplitter:
+ return self.dataset_splitter_type()
+
+ def set_dataframe(self, dataframe: pd.DataFrame):
+ self.dataframe = dataframe
+
+ # @abstractmethod
+ def prep(self):
+ pass
+
+ def load(self):
+ if self.index_path is None:
+ raise Exception(
+ "There is no index path, please call split function")
+
+ try:
+ train_idx_all, test_idx_all, train_idx_arr, val_idx_arr = self.__get_indexes__(
+ self.index_path)
+ except FileNotFoundError as e:
+ raise FileNotFoundError(f"Index files not found: {e.filename}")
 
- # remove this function
- def generate_sim_matrices(self, chemical_properties_df, two_d_dict):
-
- jaccard_sim_dict = {}
- sim_matrix_gen = SimilarityMatrixGenerator()
-
- for column in self.__similarity_related_columns__:
- key = '2D_'+column
- jaccard_sim_dict[column] = sim_matrix_gen.create_jaccard_similarity_matrices(
- two_d_dict[key])
-
- drugbank_ids = chemical_properties_df['id'].to_list()
-
- similarity_matrices = {}
-
- for column in self.__similarity_related_columns__:
- sim_matrix = jaccard_sim_dict[column]
- jaccard_sim_feature = {}
- for i in range(len(drugbank_ids)):
- jaccard_sim_feature[drugbank_ids[i]] = sim_matrix[i]
- similarity_matrices[column] = jaccard_sim_feature
-
- return similarity_matrices
-
- def generate_sim_matrices_new(self, chemical_properties_df):
- self.stopwatch.reset()
- self.stopwatch.start()
- jaccard_sim_dict = {}
- sim_matrix_gen = SimilarityMatrixGenerator()
-
- for column in self.__similarity_related_columns__:
- # key = '2D_'+column
- key = column
- jaccard_sim_dict[column] = sim_matrix_gen.create_jaccard_similarity_matrices(
- self.generated_vectors[key])
- self.stopwatch.stop()
- print(f'similarity_matrix_generation_part_1: {self.stopwatch.elapsed}')
-
- self.stopwatch.reset()
- self.stopwatch.start()
- similarity_matrices = {}
- drugbank_ids = chemical_properties_df['id'].to_list()
- new_columns = {}
- for idx in range(len(drugbank_ids)):
- new_columns[idx] = drugbank_ids[idx]
- for column in self.__similarity_related_columns__:
- new_df = pd.DataFrame.from_dict(jaccard_sim_dict[column])
- new_df = new_df.rename(index=new_columns, columns=new_columns)
- similarity_matrices[column] = new_df
- self.stopwatch.stop()
- print(f'similarity_matrix_generation_part_2: {self.stopwatch.elapsed}')
- return similarity_matrices
-
- # to convert into matrix form
- def transform_2d(self, chemical_properties_df):
- two_d_dict = {}
- for column in self.__similarity_related_columns__:
- key = '2D_'+column
- new_column = column + '_vectors'
- two_d_dict[key] = np.stack(
- chemical_properties_df[new_column].to_numpy())
-
- return two_d_dict
-
- # TODO: return an ndarray inside the dictionary
- def generate_vectors(self, chemical_properties_df):
- self.stopwatch.reset()
- self.stopwatch.start()
- vectorGenerator = VectorGenerator(chemical_properties_df)
-
- new_columns = [
- c+'_vectors' for c in self.__similarity_related_columns__]
- self.generated_vectors = vectorGenerator.generate_feature_vectors(
- self.__similarity_related_columns__)
-
- # for column, new_column in zip(self.__similarity_related_columns__, new_columns):
- # chemical_properties_df.loc[:,
- # new_column] = generated_vectors[column]
- # self.generated_vectors = generated_vectors
- self.stopwatch.stop()
- print(f'vector_generation: {self.stopwatch.elapsed}')
-
-
- # remove this function
-
-
- def sim(self, chemical_properties_df):
- self.stopwatch.reset()
- self.stopwatch.start()
- from scipy.spatial.distance import pdist
- sim_matrix_gen = SimilarityMatrixGenerator()
-
- drugbank_ids = chemical_properties_df['id'].to_list()
- similarity_matrices = {}
- for column in self.__similarity_related_columns__:
- df = pd.DataFrame(np.stack(
- chemical_properties_df[f'{column}_vectors'].values), index=drugbank_ids)
- # similarity_matrices[column] = 1 - pdist(df.to_numpy(), metric='jaccard')
- similarity_matrices[column] = sim_matrix_gen.create_jaccard_similarity_matrices(
- df.to_numpy())
- self.stopwatch.stop()
- print(f'sim: {self.stopwatch.elapsed}')
- return similarity_matrices
-
- # import pandas as pd
- # a = [[0,0,1],[0,0,1],[0,0,0]]
- # s = pd.Series(a)
- # # print(np.vstack(s.to_numpy()))
- # l = np.argmax(np.vstack(s.to_numpy()),axis = 1)
- # l
- def split_dataset(self,
- fold_size=5,
- shuffle=True,
- test_size=0.2,
- save_indexes=False):
- save_path = self.index_path
  self.prep()
- X = self.dataframe.drop('class', axis=1)
- y = self.dataframe['class']
- X_train, X_test, y_train, y_test = train_test_split(
- X, y, shuffle=shuffle, test_size=test_size, stratify=np.argmax(np.vstack(y.to_numpy()), axis=1))
- # k_fold = KFold(n_splits=fold_size, shuffle=shuffle, random_state=1)
- # folds = k_fold.split(X_train)
-
- k_fold = StratifiedKFold(
- n_splits=fold_size, shuffle=shuffle, random_state=1)
- folds = k_fold.split(X_train, np.argmax(
- np.vstack(y_train.to_numpy()), axis=1))
- train_idx_arr = []
- val_idx_arr = []
- for i, (train_index, val_index) in enumerate(folds):
- train_idx_arr.append(train_index)
- val_idx_arr.append(val_index)

- if save_indexes:
- # train_pairs = [row['id1'].join(',').row['id2'] for index, row in X_train.iterrows()]
- self.__save_indexes__(
- save_path, 'train_indexes.txt', X_train['index'].values)
- self.__save_indexes__(
- save_path, 'test_indexes.txt', X_test['index'].values)
- # self.__save_indexes__(
- # save_path, 'train_indexes.txt', X_train.index.values)
- # self.__save_indexes__(
- # save_path, 'test_indexes.txt', X_test.index.values)
+ if self.dataframe is None:
+ raise Exception("There is no dataframe")

- for i, (train_idx, val_idx) in enumerate(zip(train_idx_arr, val_idx_arr)):
- self.__save_indexes__(
- save_path, f'train_fold_{i}.txt', train_idx)
- self.__save_indexes__(
- save_path, f'validation_fold_{i}.txt', val_idx)
+ train = self.dataframe[self.dataframe.index.isin(train_idx_all)]
+ test = self.dataframe[self.dataframe.index.isin(test_idx_all)]

- self.X_train = X_train
- self.X_test = X_test
- self.y_train = y_train
- self.y_test = y_test
- self.train_indexes = X_train.index
- self.test_indexes = X_test.index
+ self.X_train = train.drop(self.class_column, axis=1)
+ self.y_train = train[self.class_column]
+ self.X_test = test.drop(self.class_column, axis=1)
+ self.y_test = test[self.class_column]
+
+ self.train_indexes = self.X_train.index
+ self.test_indexes = self.X_test.index
  self.train_idx_arr = train_idx_arr
  self.val_idx_arr = val_idx_arr
- return X_train, X_test, y_train, y_test, X_train.index, X_test.index, train_idx_arr, val_idx_arr
+
+ return self.X_train, self.X_test, self.y_train, self.y_test, self.X_train.index, self.X_test.index, train_idx_arr, val_idx_arr

  def __get_indexes__(self, path):
  train_index_path = path+'/train_indexes.txt'
@@ -259,147 +154,58 @@ class BaseDataset(ABC):
  with open(file_path, 'w') as f:
  f.write('\n'.join(str_indexes))

- # @abstractmethod
- # def prep(self):
- # pass
+ def split_dataset(self, save_indexes: bool = False):
+ # TODO class type should be parametric

- # @abstractmethod
- # def load(self):
- # pass
+ save_path = self.index_path
+ self.prep()

- # adjust this part if you create an embedding for each text type
- def prep(self):
- drug_names = self.drugs_df['name'].to_list()
- drug_ids = self.drugs_df['id'].to_list()
-
- filtered_df = self.drugs_df
- combined_df = filtered_df.copy()
-
- if self.ner_df is not None and not self.ner_df.empty:
- filtered_ner_df = self.ner_df[self.ner_df['drugbank_id'].isin(
- drug_ids)]
- filtered_ner_df = self.ner_df.copy()
-
- # TODO: if the dataset in use has no tui, cui or entity information, the code below is needed to add those columns to it
-
- # idf_calc = IDF(filtered_ner_df, [f for f in filtered_ner_df.keys()])
- idf_calc = IDF(filtered_ner_df, self.ner_columns)
- idf_calc.calculate()
- idf_scores_df = idf_calc.to_dataframe()
-
- # for key in filtered_ner_df.keys():
- for key in self.ner_columns:
- threshold = 0
- if key.startswith('tui'):
- threshold = self.tui_threshold
- if key.startswith('cui'):
- threshold = self.cui_threshold
- if key.startswith('entities'):
- threshold = self.entities_threshold
- combined_df[key] = filtered_ner_df[key]
- valid_codes = idf_scores_df[idf_scores_df[key] > threshold].index
-
- # print(f'{key}: valid code size = {len(valid_codes)}')
- combined_df[key] = combined_df[key].apply(lambda items:
- [item for item in items if item in valid_codes])
-
- moved_columns = ['id']
- moved_columns.extend(self.__similarity_related_columns__)
- chemical_properties_df = combined_df[moved_columns]
-
- chemical_properties_df = chemical_properties_df.fillna("").apply(list)
-
- # generate_vectors will return ndarrays inside a dictionary
- self.generate_vectors(chemical_properties_df)
-
- # two_d_dict = self.transform_2d(chemical_properties_df)
-
- similarity_matrices = self.generate_sim_matrices_new(
- chemical_properties_df)
-
- # similarity_matrices = self.sim(chemical_properties_df)
-
- event_categories = self.ddis_df['event_category']
- labels = event_categories.tolist()
- lb = LabelBinarizer()
- lb.fit(labels)
- classes = lb.transform(labels)
-
- # def similarity_lambda_fnc(row, value):
- # if row['id1'] in value and row['id2'] in value:
- # return value[row['id1']][row['id2']]
-
- def similarity_lambda_fnc(row, value):
- if row['id1'] in value:
- return value[row['id1']]
-
- def lambda_fnc(row, value):
- if row['id1'] in value and row['id2'] in value:
- return np.float16(np.hstack(
- (value[row['id1']], value[row['id2']])))
- # return np.hstack(
- # (value[row['id1']], value[row['id2']]), dtype=np.float16)
-
- def x_fnc(row, embeddings_after_pooling):
- if row['id1'] in embeddings_after_pooling:
- v1 = embeddings_after_pooling[row['id1']]
- else:
- v1 = np.zeros(self.embedding_size)
- if row['id2'] in embeddings_after_pooling:
- v2 = embeddings_after_pooling[row['id2']]
- else:
- v2 = np.zeros(self.embedding_size)
- return np.float16(np.hstack(
- (v1, v2)))
-
- for key, value in similarity_matrices.items():
-
- print(f'sim matrix: {key}')
- self.ddis_df[key] = self.ddis_df.apply(
- lambda_fnc, args=(value,), axis=1)
- print(self.ddis_df[key].head())
-
- for embedding_column in self.embedding_columns:
- print(f"concat {embedding_column} embeddings")
- embeddings_after_pooling = {k: self.embeddings_pooling_strategy.apply(
- v) for k, v in self.embedding_dict[embedding_column].items()}
- # column_embeddings_dict = embedding_values[embedding_column]
- self.ddis_df[embedding_column+'_embedding'] = self.ddis_df.apply(
- x_fnc, args=(embeddings_after_pooling,), axis=1)
-
- self.dataframe = self.ddis_df.copy()
- self.dataframe['class'] = list(classes)
- print(self.dataframe.shape)
+ if self.dataframe is None:
+ raise Exception("There is no data")

- def load(self):
- if self.index_path == None:
- raise Exception(
- "There is no index path, please call split function")
+ X = self.dataframe.drop(self.class_column, axis=1)
+ y = self.dataframe[self.class_column]

- # prep - split - load
- train_idx_all, test_idx_all, train_idx_arr, val_idx_arr = self.__get_indexes__(
- self.index_path)
+ X_train, X_test, y_train, y_test, X_train.index, X_test.index, train_idx_arr, val_idx_arr = self.dataset_splitter.split(
+ X=X, y=y)
+ self.X_train = X_train
+ self.X_test = X_test
+ self.y_train = y_train
+ self.y_test = y_test
+ self.train_indexes = X_train.index
+ self.test_indexes = X_test.index
+ self.train_idx_arr = train_idx_arr
+ self.val_idx_arr = val_idx_arr

- self.prep()
- train = self.dataframe[self.dataframe['index'].isin(train_idx_all)]
- test = self.dataframe[self.dataframe['index'].isin(test_idx_all)]
+ if save_indexes:
+ # train_pairs = [row['id1'].join(',').row['id2'] for index, row in X_train.iterrows()]
+ self.__save_indexes__(
+ save_path, 'train_indexes.txt', self.train_indexes.values)
+ self.__save_indexes__(
+ save_path, 'test_indexes.txt', self.test_indexes.values)
+
+ for i, (train_idx, val_idx) in enumerate(zip(train_idx_arr, val_idx_arr)):
+ self.__save_indexes__(
+ save_path, f'train_fold_{i}.txt', train_idx)
+ self.__save_indexes__(
+ save_path, f'validation_fold_{i}.txt', val_idx)

- self.X_train = train.drop('class', axis=1)
- self.y_train = train['class']
- self.X_test = test.drop('class', axis=1)
- self.y_test = test['class']
+ # return X_train, X_test, y_train, y_test, folds

- self.train_indexes = self.X_train.index
- self.test_indexes = self.X_test.index
- self.train_idx_arr = train_idx_arr
- self.val_idx_arr = val_idx_arr

- return self.X_train, self.X_test, self.y_train, self.y_test, self.X_train.index, self.X_test.index, train_idx_arr, val_idx_arr
+ class TextDatasetMixin(BaseDataset):
+ embedding_size: int
+ embedding_dict: dict
+ embeddings_pooling_strategy: PoolingStrategy | None = None
+
+ def process_text(self):
+ pass
+
+
+ # class ImageDatasetMixin(BaseModel):
+ # image_size: tuple[int, int] = Field(default=(224, 224))
+ # augmentations: list[str] = Field(default_factory=list)

- def export_as_csv(self, output_file_path, not_change: list):
- copy = self.dataframe.copy()
- for col in copy.columns:
- if col not in not_change:
- copy[col] = [
- '[' + ','.join(f"{value:.3f}" for value in row) + ']' for row in copy[col]]
- copy.to_csv(output_file_path, index=False)
+ # def process_image_data(self):
+ # print(
+ # f"Processing image data with size {self.image_size} and augmentations {self.augmentations}...")
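Taken together, core.py now centers on a pydantic BaseDataset whose split/produce flow is delegated to a pluggable DatasetSplitter (the new module shown below). The following is a minimal sketch of that flow, assuming the refactored API behaves as the diff shows; the ToyDDIDataset subclass, the toy dataframe, and the index_path value are illustrative assumptions, not part of the package:

    import numpy as np
    import pandas as pd
    from ddi_fw.datasets.core import BaseDataset
    from ddi_fw.datasets.dataset_splitter import DatasetSplitter

    class ToyDDIDataset(BaseDataset):
        # prep() is a no-op in the new base class; a real subclass would build
        # self.dataframe (feature columns plus a one-hot 'class' column) here.
        def prep(self):
            pass

    n = 60
    df = pd.DataFrame({
        "feature_a": [np.random.rand(4) for _ in range(n)],  # one toy vector per drug pair
        "class": [np.eye(2)[i % 2] for i in range(n)],       # one-hot labels, as the splitter expects
    })

    ds = ToyDDIDataset(
        dataset_name="toy",
        index_path="indexes/toy",               # hypothetical folder for the saved index files
        dataset_splitter_type=DatasetSplitter,  # instantiated on access via the computed dataset_splitter field
        columns=["feature_a"],
    )
    ds.set_dataframe(df)
    ds.split_dataset(save_indexes=False)  # stratified train/test split plus k validation folds
    items = ds.produce_inputs()           # [[name, X_train, y_train, X_test, y_test], ...]

With save_indexes=True the same call would also write train/test and per-fold index files under index_path, which a later load() call reads back to rebuild identical splits.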
ddi_fw/datasets/dataset_splitter.py ADDED
@@ -0,0 +1,39 @@
+ from typing import List, Tuple
+ import numpy as np
+ import pandas as pd
+ from pydantic import BaseModel, Field
+ from sklearn.model_selection import StratifiedKFold, train_test_split
+
+
+ class DatasetSplitter(BaseModel):
+ fold_size: int = Field(default=5, ge=2)
+ test_size: float = Field(default=0.2, ge=0.0, le=1.0)
+ shuffle: bool = True
+ random_state: int = Field(default=42)
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def split(self, X: pd.DataFrame, y: pd.Series)-> Tuple[
+ pd.DataFrame, pd.DataFrame, pd.Series, pd.Series, pd.Index, pd.Index, List[np.ndarray], List[np.ndarray]]:
+ print(
+ f"Splitting dataset into {self.fold_size} folds with shuffle={self.shuffle}...")
+ #TODO check it
+ if len(y.shape) == 1:
+ y = pd.Series(np.expand_dims(y.to_numpy(), axis=1).flatten())
+ stacked = np.vstack(tuple(y.to_numpy()))
+ stratify = np.argmax(stacked, axis=1)
+ X_train, X_test, y_train, y_test = train_test_split(
+ X, y, shuffle=self.shuffle, test_size=self.test_size, stratify=stratify)
+
+ k_fold = StratifiedKFold(
+ n_splits=self.fold_size, shuffle=self.shuffle, random_state=self.random_state)
+ folds = k_fold.split(X_train, np.argmax(
+ np.vstack(y_train.to_numpy()), axis=1))
+ train_idx_arr = []
+ val_idx_arr = []
+ for i, (train_index, val_index) in enumerate(folds):
+ train_idx_arr.append(train_index)
+ val_idx_arr.append(val_index)
+
+ return X_train, X_test, y_train, y_test, X_train.index, X_test.index, train_idx_arr, val_idx_arr
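For reference, a self-contained usage sketch of the new splitter, assuming only that the module is importable at the path core.py uses above; the feature columns, sample count, and three-class one-hot labels are illustrative:

    import numpy as np
    import pandas as pd
    from ddi_fw.datasets.dataset_splitter import DatasetSplitter

    # split() vstacks y and takes argmax to build the stratification target,
    # so each y entry is expected to be a one-hot vector rather than a plain class id.
    n = 100
    X = pd.DataFrame({"f1": np.random.rand(n), "f2": np.random.rand(n)})
    y = pd.Series([np.eye(3)[i % 3] for i in range(n)])

    splitter = DatasetSplitter()  # defaults: fold_size=5, test_size=0.2, shuffle=True, random_state=42
    (X_train, X_test, y_train, y_test,
     train_idx, test_idx, train_folds, val_folds) = splitter.split(X=X, y=y)

    print(len(X_train), len(X_test), len(train_folds))  # e.g. 80 20 5

The train_folds/val_folds lists hold positional indices into X_train produced by StratifiedKFold, which is what BaseDataset stores in train_idx_arr and val_idx_arr.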