upgini 1.1.280.dev0__py3-none-any.whl → 1.2.31a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of upgini might be problematic. Click here for more details.
- upgini/__about__.py +1 -1
- upgini/__init__.py +4 -20
- upgini/autofe/all_operands.py +39 -9
- upgini/autofe/binary.py +148 -45
- upgini/autofe/date.py +197 -26
- upgini/autofe/feature.py +102 -19
- upgini/autofe/groupby.py +22 -22
- upgini/autofe/operand.py +9 -6
- upgini/autofe/unary.py +83 -41
- upgini/autofe/vector.py +8 -8
- upgini/data_source/data_source_publisher.py +128 -5
- upgini/dataset.py +50 -386
- upgini/features_enricher.py +931 -542
- upgini/http.py +27 -16
- upgini/lazy_import.py +35 -0
- upgini/metadata.py +84 -59
- upgini/metrics.py +164 -34
- upgini/normalizer/normalize_utils.py +197 -0
- upgini/resource_bundle/strings.properties +66 -51
- upgini/search_task.py +10 -4
- upgini/utils/Roboto-Regular.ttf +0 -0
- upgini/utils/base_search_key_detector.py +14 -12
- upgini/utils/country_utils.py +16 -0
- upgini/utils/custom_loss_utils.py +39 -36
- upgini/utils/datetime_utils.py +98 -45
- upgini/utils/deduplicate_utils.py +135 -112
- upgini/utils/display_utils.py +46 -15
- upgini/utils/email_utils.py +54 -16
- upgini/utils/feature_info.py +172 -0
- upgini/utils/features_validator.py +34 -20
- upgini/utils/ip_utils.py +100 -1
- upgini/utils/phone_utils.py +343 -0
- upgini/utils/postal_code_utils.py +34 -0
- upgini/utils/sklearn_ext.py +28 -19
- upgini/utils/target_utils.py +113 -57
- upgini/utils/warning_counter.py +1 -0
- upgini/version_validator.py +8 -4
- {upgini-1.1.280.dev0.dist-info → upgini-1.2.31a2.dist-info}/METADATA +31 -16
- upgini-1.2.31a2.dist-info/RECORD +65 -0
- upgini/normalizer/phone_normalizer.py +0 -340
- upgini-1.1.280.dev0.dist-info/RECORD +0 -62
- {upgini-1.1.280.dev0.dist-info → upgini-1.2.31a2.dist-info}/WHEEL +0 -0
- {upgini-1.1.280.dev0.dist-info → upgini-1.2.31a2.dist-info}/licenses/LICENSE +0 -0
upgini/utils/phone_utils.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
1
4
|
import pandas as pd
|
|
5
|
+
from pandas.api.types import is_float_dtype, is_object_dtype, is_string_dtype
|
|
2
6
|
|
|
7
|
+
from upgini.errors import ValidationError
|
|
3
8
|
from upgini.utils.base_search_key_detector import BaseSearchKeyDetector
|
|
4
9
|
|
|
5
10
|
|
|
@@ -9,3 +14,341 @@ class PhoneSearchKeyDetector(BaseSearchKeyDetector):
|
|
|
9
14
|
|
|
10
15
|
def _is_search_key_by_values(self, column: pd.Series) -> bool:
|
|
11
16
|
return False
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PhoneSearchKeyConverter:
    """Normalise a phone-number column of a DataFrame to nullable ``Int64`` values.

    Values that cannot be parsed, or whose magnitude is outside the range accepted
    by ``validate_length``, become ``<NA>``. When ``country_column`` is given, a
    number whose digit count equals the national length listed for its country in
    ``COUNTRIES_PREFIXES`` gets that country's calling code prepended.
    """

    def __init__(self, phone_column: str, country_column: Optional[str] = None):
        # Name of the column holding the raw phone values.
        self.phone_column = phone_column
        # Optional column whose values are looked up in COUNTRIES_PREFIXES.
        self.country_column = country_column

    def convert(self, df: pd.DataFrame) -> pd.DataFrame:
        """Convert ``phone_column`` of ``df`` in place and return ``df``.

        :param df: frame containing ``phone_column`` (and ``country_column`` if set).
        :return: the same frame with ``phone_column`` as ``Int64``.
        :raises ValidationError: if ``phone_column`` has an unsupported dtype.
        """
        df = self.phone_to_int(df)
        # Row-wise apply on an empty frame may return a DataFrame instead of a Series
        # (pandas-version dependent); there is nothing to prefix anyway.
        if self.country_column is not None and len(df) > 0:
            df[self.phone_column] = df.apply(self.add_prefix, axis=1)
            df[self.phone_column] = df[self.phone_column].astype("Int64")
        return df

    def add_prefix(self, row):
        """Return the row's phone with the country calling code prepended, if applicable.

        The prefix is added only when the row's country is a key of
        ``COUNTRIES_PREFIXES`` and the phone has exactly that country's national
        digit count. Otherwise the phone value (including a missing one) is
        returned unchanged.
        """
        phone = row[self.phone_column]
        if pd.isna(phone):
            return phone
        country = row[self.country_column]
        # Keys are upper-case two-letter codes; accept "us" / " US " as well.
        if isinstance(country, str):
            country = country.strip().upper()
        country_prefix_tuple = self.COUNTRIES_PREFIXES.get(country)
        if country_prefix_tuple is not None:
            country_prefix, number_of_digits = country_prefix_tuple
            if len(str(phone)) == number_of_digits:
                return int(country_prefix + str(phone))
        return phone

    def phone_to_int(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Convention: phone number is always presented as int number.
        phone_number = Country code + National Destination Code + Subscriber Number.
        Examples:
        41793834315 for Switzerland
        46767040672 for Sweden
        861065529988 for China
        18143008198 for the USA
        Inplace conversion of phone to int.

        String/object columns: all non-decimal characters are removed before the
        conversion. Float and any integer dtype (NumPy or nullable) are accepted.
        None will be set for phone numbers that couldn't be converted to int
        or that fail ``validate_length``.

        :raises ValidationError: if the column has an unsupported dtype.
        """
        column = df[self.phone_column]
        if is_string_dtype(column) or is_object_dtype(column):
            convert_func = self.phone_str_to_int_safe
        elif is_float_dtype(column):
            convert_func = self.phone_float_to_int_safe
        elif pd.api.types.is_integer_dtype(column):
            # Covers int64/Int64 as before, plus int32, uint64, Int32, UInt64, ...
            convert_func = self.phone_int_to_int_safe
        else:
            raise ValidationError(
                f"phone_column_name {self.phone_column} doesn't have supported dtype. "
                f"Dataset dtypes: {df.dtypes}. "
                f"Contact developer and request to implement conversion of {self.phone_column} to int"
            )
        df[self.phone_column] = column.apply(convert_func).astype("Int64")
        return df

    @staticmethod
    def phone_float_to_int_safe(value: float) -> Optional[int]:
        """Truncate a float phone to int and validate it; ``None`` for NaN/inf/out of range."""
        try:
            return PhoneSearchKeyConverter.validate_length(int(value))
        except (TypeError, ValueError, OverflowError):
            return None

    @staticmethod
    def phone_int_to_int_safe(value: int) -> Optional[int]:
        """Validate an integer phone; ``None`` for missing values or out-of-range numbers."""
        try:
            return PhoneSearchKeyConverter.validate_length(int(value))
        except (TypeError, ValueError, OverflowError):
            return None

    @staticmethod
    def phone_str_to_int_safe(value: str) -> Optional[int]:
        """Keep only the decimal digits of ``value`` and validate the resulting int.

        A trailing ``".0"`` (a stringified float) is dropped first so it does not add
        a spurious ``0`` digit. Returns ``None`` when no digits remain or the number
        fails ``validate_length``.
        """
        try:
            value = str(value)
            if value.endswith(".0"):
                value = value[:-2]
            # isdecimal (unlike isdigit) excludes characters such as "²" that int() rejects.
            numeric_string = "".join(filter(str.isdecimal, value))
            return PhoneSearchKeyConverter.validate_length(int(numeric_string))
        except Exception:  # best-effort conversion: any unparsable value becomes None
            return None

    @staticmethod
    def validate_length(value: int) -> Optional[int]:
        """Return ``value`` if it has 8 to 15 digits (10**7 .. 10**15 - 1), else ``None``."""
        if 10_000_000 <= value <= 999_999_999_999_999:
            return value
        return None

    # Two-letter country code -> (calling-code prefix, national number digit count).
    # add_prefix prepends the prefix only when the phone has exactly that many digits.
    COUNTRIES_PREFIXES = {
        "US": ("1", 10),
        "CA": ("1", 10),
        "AI": ("1", 10),
        "AG": ("1", 10),
        "AS": ("1", 10),
        "BB": ("1", 10),
        "BS": ("1", 10),
        "VG": ("1", 10),
        "VI": ("1", 10),
        "KY": ("1", 10),
        "BM": ("1", 10),
        "GD": ("1", 10),
        "TC": ("1", 10),
        "MS": ("1", 10),
        "MP": ("1", 10),
        "GU": ("1", 10),
        "SX": ("1", 10),
        "LC": ("1", 10),
        "DM": ("1", 10),
        "VC": ("1", 10),
        "PR": ("1", 10),
        "TT": ("1", 10),
        "KN": ("1", 10),
        "JM": ("1", 10),
        "EG": ("20", 9),
        "SS": ("211", 9),
        "MA": ("212", 9),
        "EH": ("212", 4),
        "DZ": ("213", 8),
        "TN": ("216", 8),
        "LY": ("218", 9),
        "GM": ("220", 6),
        "SN": ("221", 9),
        "MR": ("222", 7),
        "ML": ("223", 8),
        "GN": ("224", 9),
        "CI": ("225", 7),
        "BF": ("226", 8),
        "NE": ("227", 8),
        "TG": ("228", 8),
        "BJ": ("229", 8),
        "MU": ("230", 7),
        "LR": ("231", 9),
        "SL": ("232", 8),
        "GH": ("233", 9),
        "NG": ("234", 9),
        "TD": ("235", 8),
        "CF": ("236", 7),
        "CM": ("237", 9),
        "CV": ("238", 7),
        "ST": ("239", 7),
        "GQ": ("240", 9),
        "GA": ("241", 8),
        "CG": ("242", 7),
        "CD": ("243", 9),
        "AO": ("244", 9),
        "GW": ("245", 6),
        "IO": ("246", 7),
        "AC": ("247", 5),
        "SC": ("248", 7),
        "SD": ("249", 9),
        "RW": ("250", 9),
        "ET": ("251", 9),
        "SO": ("252", 9),
        "DJ": ("253", 8),
        "KE": ("254", 9),
        "TZ": ("255", 9),
        "UG": ("256", 9),
        "BI": ("257", 8),
        "MZ": ("258", 8),
        "ZM": ("260", 9),
        "MG": ("261", 9),
        "RE": ("262", 9),
        "YT": ("262", 9),
        "TF": ("262", 9),
        "ZW": ("263", 9),
        "NA": ("264", 9),
        "MW": ("265", 7),
        "LS": ("266", 8),
        "BW": ("267", 7),
        "SZ": ("268", 8),
        "KM": ("269", 7),
        "ZA": ("27", 10),
        "SH": ("290", 5),
        "TA": ("290", 5),
        "ER": ("291", 7),
        "AT": ("43", 10),
        "AW": ("297", 7),
        "FO": ("298", 6),
        "GL": ("299", 6),
        "GR": ("30", 10),
        "BE": ("32", 8),
        "FR": ("33", 9),
        "ES": ("34", 9),
        "GI": ("350", 8),
        "PE": ("51", 8),
        "MX": ("52", 10),
        "CU": ("53", 8),
        "AR": ("54", 10),
        "BR": ("55", 10),
        "CL": ("56", 9),
        "CO": ("57", 8),
        "VE": ("58", 10),
        "PT": ("351", 9),
        "LU": ("352", 8),
        "IE": ("353", 8),
        "IS": ("354", 7),
        "AL": ("355", 8),
        "MT": ("356", 8),
        "CY": ("357", 8),
        "FI": ("358", 9),
        "BG": ("359", 8),
        "HU": ("36", 8),
        "LT": ("370", 8),
        "LV": ("371", 8),
        "EE": ("372", 7),
        "MD": ("373", 8),
        "AM": ("374", 8),
        "BY": ("375", 9),
        "AD": ("376", 6),
        "MC": ("377", 8),
        "SM": ("378", 9),
        "VA": ("3906698", 5),
        "UA": ("380", 9),
        "RS": ("381", 9),
        "ME": ("382", 8),
        "HR": ("385", 8),
        "SI": ("386", 8),
        "BA": ("387", 8),
        "MK": ("389", 8),
        "MY": ("60", 9),
        "AU": ("61", 9),
        "CX": ("61", 9),
        "CC": ("61", 9),
        "ID": ("62", 9),
        "PH": ("632", 7),
        "NZ": ("64", 8),
        "PN": ("64", 8),
        "SG": ("65", 8),
        "TH": ("66", 8),
        "IT": ("39", 10),
        "RO": ("40", 9),
        "CH": ("41", 9),
        "CZ": ("420", 9),
        "SK": ("421", 9),
        "GB": ("44", 10),
        "LI": ("423", 7),
        "GG": ("44", 10),
        "IM": ("44", 10),
        "JE": ("44", 10),
        "DK": ("45", 8),
        "SE": ("46", 8),
        "BD": ("880", 8),
        "TW": ("886", 9),
        "JP": ("81", 9),
        "KR": ("82", 9),
        "VN": ("84", 10),
        "KP": ("850", 8),
        "HK": ("852", 8),
        "MO": ("853", 8),
        "KH": ("855", 8),
        "LA": ("856", 8),
        "NO": ("47", 8),
        "SJ": ("47", 8),
        "BV": ("47", 8),
        "PL": ("48", 9),
        "DE": ("49", 10),
        "TR": ("90", 10),
        "IN": ("91", 10),
        "PK": ("92", 9),
        "AF": ("93", 9),
        "LK": ("94", 9),
        "MM": ("95", 7),
        "IR": ("98", 10),
        "MV": ("960", 7),
        "LB": ("961", 7),
        "JO": ("962", 9),
        "SY": ("963", 10),
        "IQ": ("964", 10),
        "KW": ("965", 7),
        "SA": ("966", 9),
        "YE": ("967", 7),
        "OM": ("968", 8),
        "PS": ("970", 8),
        "AE": ("971", 8),
        "IL": ("972", 9),
        "BH": ("973", 8),
        "QA": ("974", 8),
        "BT": ("975", 7),
        "MN": ("976", 8),
        "NP": ("977", 8),
        "TJ": ("992", 9),
        "TM": ("993", 8),
        "AZ": ("994", 9),
        "GE": ("995", 9),
        "KG": ("996", 9),
        "UZ": ("998", 9),
        "FK": ("500", 5),
        "BZ": ("501", 7),
        "GT": ("502", 8),
        "SV": ("503", 8),
        "HN": ("504", 8),
        "NI": ("505", 8),
        "CR": ("506", 8),
        "PA": ("507", 7),
        "PM": ("508", 6),
        "HT": ("509", 8),
        "GS": ("500", 5),
        "MF": ("590", 9),
        "BL": ("590", 9),
        "GP": ("590", 9),
        "BO": ("591", 9),
        "GY": ("592", 9),
        "EC": ("593", 9),
        "GF": ("594", 9),
        "PY": ("595", 9),
        "MQ": ("596", 9),
        "SR": ("597", 9),
        "UY": ("598", 9),
        "CW": ("599", 9),
        "BQ": ("599", 9),
        "RU": ("7", 10),
        "KZ": ("7", 10),
        "TL": ("670", 7),
        "NF": ("672", 7),
        "HM": ("672", 7),
        "BN": ("673", 7),
        "NR": ("674", 7),
        "PG": ("675", 7),
        "TO": ("676", 7),
        "SB": ("677", 7),
        "VU": ("678", 7),
        "FJ": ("679", 7),
        "PW": ("680", 7),
        "WF": ("681", 7),
        "CK": ("682", 5),
        "NU": ("683", 7),
        "WS": ("685", 7),
        "KI": ("686", 7),
        "NC": ("687", 7),
        "TV": ("688", 7),
        "PF": ("689", 7),
        "TK": ("690", 7),
        "FM": ("691", 7),
        "MH": ("692", 7),
    }
|
|
@@ -1,4 +1,9 @@
|
|
|
1
1
|
import pandas as pd
|
|
2
|
+
from pandas.api.types import (
|
|
3
|
+
is_float_dtype,
|
|
4
|
+
is_object_dtype,
|
|
5
|
+
is_string_dtype,
|
|
6
|
+
)
|
|
2
7
|
|
|
3
8
|
from upgini.utils.base_search_key_detector import BaseSearchKeyDetector
|
|
4
9
|
|
|
@@ -9,3 +14,32 @@ class PostalCodeSearchKeyDetector(BaseSearchKeyDetector):
|
|
|
9
14
|
|
|
10
15
|
def _is_search_key_by_values(self, column: pd.Series) -> bool:
|
|
11
16
|
return False
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PostalCodeSearchKeyConverter:
    """Normalise a postal-code column of a DataFrame to upper-case alphanumeric strings."""

    def __init__(self, postal_code_column: str):
        # Name of the column holding the raw postal codes.
        self.postal_code_column = postal_code_column

    def convert(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalise ``postal_code_column`` of ``df`` in place and return ``df``.

        1. Purely numeric columns (numeric strings or floats) are rendered as integer
           strings, so ``12345.0`` becomes ``"12345"``. If that cast is impossible
           (alphanumeric codes mixed in, non-integral floats), the column is left
           as-is for step 2.
        2. Every value is cast to the pandas ``string`` dtype. A trailing ``.0``
           float artefact is dropped and letters are upper-cased. Characters other
           than ``0-9``/``A-Z`` are removed, and leading zeros are stripped (a value
           made only of zeros becomes ``"0"``).

        Missing values stay missing.
        """
        column = self.postal_code_column
        if is_string_dtype(df[column]) or is_object_dtype(df[column]):
            try:
                df[column] = df[column].astype("string").astype("float64").astype("Int64").astype("string")
            except Exception:
                # Not every value is numeric (e.g. "SW1A 1AA"): best-effort, normalise
                # value by value below instead.
                pass
        elif is_float_dtype(df[column]):
            try:
                df[column] = df[column].astype("Int64").astype("string")
            except (TypeError, ValueError, OverflowError):
                # Non-integral or infinite floats cannot be cast to Int64; fall back to the
                # per-value normalisation below rather than failing the whole conversion.
                pass

        df[column] = (
            df[column]
            .astype("string")
            .str.replace(r"\.0+$", "", regex=True)  # drop float artefacts like "12345.0" before stripping "."
            .str.upper()
            .str.replace(r"[^0-9A-Z]", "", regex=True)  # remove non alphanumeric characters
            .str.replace(r"^0+\B", "", regex=True)  # remove leading zeros
        )
        return df
|
upgini/utils/sklearn_ext.py
CHANGED
|
@@ -17,7 +17,7 @@ from sklearn.base import clone, is_classifier
|
|
|
17
17
|
from sklearn.exceptions import FitFailedWarning, NotFittedError
|
|
18
18
|
from sklearn.metrics import check_scoring
|
|
19
19
|
from sklearn.metrics._scorer import _MultimetricScorer
|
|
20
|
-
from sklearn.model_selection import check_cv
|
|
20
|
+
from sklearn.model_selection import StratifiedKFold, check_cv
|
|
21
21
|
from sklearn.utils.fixes import np_version, parse_version
|
|
22
22
|
from sklearn.utils.validation import indexable
|
|
23
23
|
|
|
@@ -312,25 +312,34 @@ def cross_validate(
|
|
|
312
312
|
ret[key] = train_scores_dict[name]
|
|
313
313
|
|
|
314
314
|
return ret
|
|
315
|
-
except
|
|
315
|
+
except ValueError as e:
|
|
316
316
|
# logging.exception("Failed to execute overriden cross_validate. Fallback to original")
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
317
|
+
if hasattr(e, "args") and len(e.args) > 0 and "Only one class present in y_true" in e.args[0]:
|
|
318
|
+
# Try change CV to StratifiedKFold and retry
|
|
319
|
+
if hasattr(cv, "shuffle"):
|
|
320
|
+
shuffle = cv.shuffle
|
|
321
|
+
else:
|
|
322
|
+
shuffle = False
|
|
323
|
+
if hasattr(cv, "random_state") and shuffle:
|
|
324
|
+
random_state = cv.random_state
|
|
325
|
+
else:
|
|
326
|
+
random_state = None
|
|
327
|
+
return cross_validate(
|
|
328
|
+
estimator,
|
|
329
|
+
x,
|
|
330
|
+
y,
|
|
331
|
+
groups=groups,
|
|
332
|
+
scoring=scoring,
|
|
333
|
+
cv=StratifiedKFold(n_splits=cv.get_n_splits(), shuffle=shuffle, random_state=random_state),
|
|
334
|
+
n_jobs=n_jobs,
|
|
335
|
+
verbose=verbose,
|
|
336
|
+
fit_params=fit_params,
|
|
337
|
+
pre_dispatch=pre_dispatch,
|
|
338
|
+
return_train_score=return_train_score,
|
|
339
|
+
return_estimator=return_estimator,
|
|
340
|
+
error_score=error_score,
|
|
341
|
+
)
|
|
342
|
+
raise e
|
|
334
343
|
|
|
335
344
|
|
|
336
345
|
def _fit_and_score(
|