upgini 1.1.312a5__py3-none-any.whl → 1.1.313__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of upgini may be problematic; see the release advisory for more details.

upgini/dataset.py CHANGED
@@ -1,24 +1,32 @@
1
1
  import csv
2
+ import hashlib
2
3
  import logging
3
4
  import tempfile
4
5
  import time
6
+ from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address
5
7
  from pathlib import Path
6
- from typing import Any, Callable, Dict, List, Optional, Tuple
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
7
9
 
8
10
  import numpy as np
9
11
  import pandas as pd
12
+ from pandas.api.types import is_bool_dtype as is_bool
13
+ from pandas.api.types import is_datetime64_any_dtype as is_datetime
10
14
  from pandas.api.types import (
11
15
  is_float_dtype,
12
16
  is_integer_dtype,
13
17
  is_numeric_dtype,
14
18
  is_object_dtype,
19
+ is_period_dtype,
15
20
  is_string_dtype,
16
21
  )
17
22
 
18
23
  from upgini.errors import ValidationError
19
24
  from upgini.http import ProgressStage, SearchProgress, _RestClient
20
25
  from upgini.metadata import (
26
+ ENTITY_SYSTEM_RECORD_ID,
21
27
  EVAL_SET_INDEX,
28
+ SEARCH_KEY_UNNEST,
29
+ SYSTEM_COLUMNS,
22
30
  SYSTEM_RECORD_ID,
23
31
  TARGET,
24
32
  DataType,
@@ -32,8 +40,10 @@ from upgini.metadata import (
32
40
  RuntimeParameters,
33
41
  SearchCustomization,
34
42
  )
43
+ from upgini.normalizer.phone_normalizer import PhoneNormalizer
35
44
  from upgini.resource_bundle import ResourceBundle, get_custom_bundle
36
45
  from upgini.search_task import SearchTask
46
+ from upgini.utils import combine_search_keys, find_numbers_with_decimal_comma
37
47
  from upgini.utils.email_utils import EmailSearchKeyConverter
38
48
  from upgini.utils.target_utils import balance_undersample
39
49
 
@@ -107,6 +117,7 @@ class Dataset: # (pd.DataFrame):
107
117
  self.meaning_types = meaning_types
108
118
  self.search_keys = search_keys
109
119
  self.unnest_search_keys = unnest_search_keys
120
+ self.ignore_columns = []
110
121
  self.hierarchical_group_keys = []
111
122
  self.hierarchical_subgroup_keys = []
112
123
  self.file_upload_id: Optional[str] = None
@@ -160,6 +171,241 @@ class Dataset: # (pd.DataFrame):
160
171
  if len(self.data) > self.MAX_ROWS:
161
172
  raise ValidationError(self.bundle.get("dataset_too_many_rows_registered").format(self.MAX_ROWS))
162
173
 
174
    def __rename_columns(self):
        """Rewrite data-frame column names into a server-safe form.

        System columns (TARGET, EVAL_SET_INDEX, SYSTEM_RECORD_ID,
        ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST) are kept as-is; every
        other column is truncated to 250 chars, suffixed with a 6-hex-digit
        sha256 of its original name, lower-cased, and stripped of characters
        outside [a-z0-9].  `self.meaning_types`, `self.search_keys` and
        `self.columns_renaming` are rewritten to the new names in lockstep.
        """
        # self.logger.info("Replace restricted symbols in column names")
        new_columns = []
        dup_counter = 0
        for column in self.data.columns:
            # System columns keep their exact names (identity mapping).
            if column in [TARGET, EVAL_SET_INDEX, SYSTEM_RECORD_ID, ENTITY_SYSTEM_RECORD_ID, SEARCH_KEY_UNNEST]:
                self.columns_renaming[column] = column
                new_columns.append(column)
                continue

            new_column = str(column)
            # Suffix is derived from the *original* name, before truncation.
            suffix = hashlib.sha256(new_column.encode()).hexdigest()[:6]
            if len(new_column) == 0:
                raise ValidationError(self.bundle.get("dataset_empty_column_names"))
            # db limit for column length
            if len(new_column) > 250:
                new_column = new_column[:250]

            # make column name unique relative to server features
            new_column = f"{new_column}_{suffix}"

            new_column = new_column.lower()

            # if column starts with non alphabetic symbol then add "a" to the beginning of string
            if ord(new_column[0]) not in range(ord("a"), ord("z") + 1):
                new_column = "a" + new_column

            # replace unsupported characters to "_"
            for idx, c in enumerate(new_column):
                if ord(c) not in range(ord("a"), ord("z") + 1) and ord(c) not in range(ord("0"), ord("9") + 1):
                    new_column = new_column[:idx] + "_" + new_column[idx + 1 :]

            # Collisions after sanitization get a numeric suffix.
            if new_column in new_columns:
                new_column = f"{new_column}_{dup_counter}"
                dup_counter += 1
            new_columns.append(new_column)

            # self.data.columns.values[col_idx] = new_column
            # self.rename(columns={column: new_column}, inplace=True)
            # Rebuilt on every iteration: only the key matching the current
            # original column name is substituted each pass.
            self.meaning_types = {
                (new_column if key == str(column) else key): value for key, value in self.meaning_types_checked.items()
            }
            self.search_keys = [
                tuple(new_column if key == str(column) else key for key in keys) for keys in self.search_keys_checked
            ]
            self.columns_renaming[new_column] = str(column)
        self.data.columns = new_columns
        # NOTE(review): etalon_def is dropped here — presumably so the
        # *_checked accessor recomputes it from the renamed meaning types;
        # confirm against the etalon_def_checked property.
        self.etalon_def = None
+ def __validate_too_long_string_values(self):
224
+ """Check that string values less than maximum characters for LLM"""
225
+ # self.logger.info("Validate too long string values")
226
+ for col in self.data.columns:
227
+ if is_string_dtype(self.data[col]) or is_object_dtype(self.data[col]):
228
+ max_length: int = self.data[col].astype("str").str.len().max()
229
+ if max_length > self.MAX_STRING_FEATURE_LENGTH:
230
+ self.data[col] = self.data[col].astype("str").str.slice(stop=self.MAX_STRING_FEATURE_LENGTH)
231
+
232
+ def __convert_bools(self):
233
+ """Convert bool columns to string"""
234
+ # self.logger.info("Converting bool to int")
235
+ for col in self.data.columns:
236
+ if is_bool(self.data[col]):
237
+ self.data[col] = self.data[col].astype("str")
238
+
239
+ def __convert_float16(self):
240
+ """Convert float16 to float"""
241
+ # self.logger.info("Converting float16 to float")
242
+ for col in self.data.columns:
243
+ if is_float_dtype(self.data[col]):
244
+ self.data[col] = self.data[col].astype("float64")
245
+
246
+ def __correct_decimal_comma(self):
247
+ """Check DataSet for decimal commas and fix them"""
248
+ # self.logger.info("Correct decimal commas")
249
+ columns_to_fix = find_numbers_with_decimal_comma(self.data)
250
+ if len(columns_to_fix) > 0:
251
+ self.logger.warning(f"Convert strings with decimal comma to float: {columns_to_fix}")
252
+ for col in columns_to_fix:
253
+ self.data[col] = self.data[col].astype("string").str.replace(",", ".", regex=False).astype(np.float64)
254
+
255
+ @staticmethod
256
+ def _ip_to_int(ip: Optional[_BaseAddress]) -> Optional[int]:
257
+ try:
258
+ if isinstance(ip, (IPv4Address, IPv6Address)):
259
+ return int(ip)
260
+ except Exception:
261
+ pass
262
+
263
+ @staticmethod
264
+ def _ip_to_int_str(ip: Optional[_BaseAddress]) -> Optional[str]:
265
+ try:
266
+ if isinstance(ip, (IPv4Address, IPv6Address)):
267
+ return str(int(ip))
268
+ except Exception:
269
+ pass
270
+
271
+ @staticmethod
272
+ def _safe_ip_parse(ip: Union[str, int, IPv4Address, IPv6Address]) -> Optional[_BaseAddress]:
273
+ try:
274
+ return ip_address(ip)
275
+ except ValueError:
276
+ pass
277
+
278
+ @staticmethod
279
+ def _is_ipv4(ip: Optional[_BaseAddress]):
280
+ return ip is not None and (
281
+ isinstance(ip, IPv4Address) or (isinstance(ip, IPv6Address) and ip.ipv4_mapped is not None)
282
+ )
283
+
284
+ @staticmethod
285
+ def _to_ipv4(ip: Optional[_BaseAddress]) -> Optional[IPv4Address]:
286
+ if isinstance(ip, IPv4Address):
287
+ return ip
288
+ return None
289
+
290
+ @staticmethod
291
+ def _to_ipv6(ip: Optional[_BaseAddress]) -> Optional[IPv6Address]:
292
+ if isinstance(ip, IPv6Address):
293
+ return ip
294
+ if isinstance(ip, IPv4Address):
295
+ return IPv6Address("::ffff:" + str(ip))
296
+ return None
297
+
298
    def __convert_ip(self):
        """Replace the IP search-key column with two derived columns.

        The original column is parsed, then split into `<ip>_v4` (nullable
        Int64; None for non-IPv4 values) and `<ip>_v6` (decimal-string form
        of the IPv6 / IPv4-mapped value).  All bookkeeping structures
        (etalon_def, meaning_types, columns_renaming, search_keys) are
        updated to the new columns and the source column is dropped.
        """
        ip = self.etalon_def_checked.get(FileColumnMeaningType.IP_ADDRESS.value)
        if ip is not None and ip in self.data.columns:
            self.logger.info("Convert ip address to int")
            # Remove all references to the source column first; they are
            # re-added below under the derived names.
            del self.etalon_def[FileColumnMeaningType.IP_ADDRESS.value]
            del self.meaning_types[ip]
            original_ip = self.columns_renaming[ip]
            del self.columns_renaming[ip]

            # Flatten the grouped search keys, dropping the source column.
            search_keys = set()
            for tup in self.search_keys_checked:
                search_keys.update(tup)
            search_keys.remove(ip)

            # Unparseable values become None; all-None means no usable IPs.
            self.data[ip] = self.data[ip].apply(self._safe_ip_parse)
            if self.data[ip].isnull().all():
                raise ValidationError(self.bundle.get("invalid_ip").format(ip))

            ipv4 = ip + "_v4"
            self.data[ipv4] = self.data[ip].apply(self._to_ipv4).apply(self._ip_to_int).astype("Int64")
            self.meaning_types[ipv4] = FileColumnMeaningType.IP_ADDRESS
            self.etalon_def[FileColumnMeaningType.IP_ADDRESS.value] = ipv4
            search_keys.add(ipv4)
            self.columns_renaming[ipv4] = original_ip

            ipv6 = ip + "_v6"
            # String, not int: IPv6 integers overflow fixed-width numeric dtypes.
            self.data[ipv6] = (
                self.data[ip]
                .apply(self._to_ipv6)
                .apply(self._ip_to_int_str)
                .astype("string")
                # .str.replace(".0", "", regex=False)
            )
            self.data = self.data.drop(columns=ip)
            self.meaning_types[ipv6] = FileColumnMeaningType.IPV6_ADDRESS
            self.etalon_def[FileColumnMeaningType.IPV6_ADDRESS.value] = ipv6
            search_keys.add(ipv6)
            self.columns_renaming[ipv6] = original_ip
            self.search_keys = combine_search_keys(search_keys)
+ def __normalize_iso_code(self):
340
+ iso_code = self.etalon_def_checked.get(FileColumnMeaningType.COUNTRY.value)
341
+ if iso_code is not None and iso_code in self.data.columns:
342
+ # self.logger.info("Normalize iso code column")
343
+ self.data[iso_code] = (
344
+ self.data[iso_code]
345
+ .astype("string")
346
+ .str.upper()
347
+ .str.replace(r"[^A-Z]", "", regex=True)
348
+ .str.replace("UK", "GB", regex=False)
349
+ )
350
+ if (self.data[iso_code] == "").all():
351
+ raise ValidationError(self.bundle.get("invalid_country").format(iso_code))
352
+
353
    def __normalize_postal_code(self):
        """Normalize the postal-code column to bare upper-case alphanumerics.

        Numeric-looking codes are first round-tripped through Float64/Int64
        to drop a trailing ".0"; then non-alphanumerics and leading zeros
        are removed.  Raises ValidationError if every value ends up empty.
        """
        postal_code = self.etalon_def_checked.get(FileColumnMeaningType.POSTAL_CODE.value)
        if postal_code is not None and postal_code in self.data.columns:
            # self.logger.info("Normalize postal code")

            if is_string_dtype(self.data[postal_code]) or is_object_dtype(self.data[postal_code]):
                try:
                    # Best-effort: only succeeds when every value is numeric;
                    # alphanumeric codes keep their original form.
                    self.data[postal_code] = (
                        self.data[postal_code].astype("string").astype("Float64").astype("Int64").astype("string")
                    )
                except Exception:
                    # Deliberate broad swallow: non-numeric codes fall through
                    # to the generic cleanup below.
                    pass
            elif is_float_dtype(self.data[postal_code]):
                self.data[postal_code] = self.data[postal_code].astype("Int64").astype("string")

            self.data[postal_code] = (
                self.data[postal_code]
                .astype("string")
                .str.upper()
                .str.replace(r"[^0-9A-Z]", "", regex=True)  # remove non alphanumeric characters
                .str.replace(r"^0+\B", "", regex=True)  # remove leading zeros
            )
            if (self.data[postal_code] == "").all():
                raise ValidationError(self.bundle.get("invalid_postal_code").format(postal_code))
+ def __normalize_hem(self):
379
+ hem = self.etalon_def_checked.get(FileColumnMeaningType.HEM.value)
380
+ if hem is not None and hem in self.data.columns:
381
+ self.data[hem] = self.data[hem].str.lower()
382
+
383
    def __remove_old_dates(self, silent_mode: bool = False):
        """Drop rows whose date/datetime timestamp is before MIN_SUPPORTED_DATE_TS.

        Only applies when the date column is already numeric (epoch-like).
        Raises ValidationError if *every* row is too old; otherwise logs a
        warning (and prints it unless silent_mode) and bumps warning_counter.
        """
        date_column = self.etalon_def_checked.get(FileColumnMeaningType.DATE.value) or self.etalon_def_checked.get(
            FileColumnMeaningType.DATETIME.value
        )
        if date_column is not None and is_numeric_dtype(self.data[date_column]):
            old_subset = self.data[self.data[date_column] < self.MIN_SUPPORTED_DATE_TS]
            if len(old_subset) > 0:
                self.logger.info(f"df before dropping old rows: {self.data.shape}")
                self.data.drop(index=old_subset.index, inplace=True)  # type: ignore
                self.logger.info(f"df after dropping old rows: {self.data.shape}")
                if len(self.data) == 0:
                    raise ValidationError(self.bundle.get("dataset_all_dates_old"))
                else:
                    msg = self.bundle.get("dataset_drop_old_dates")
                    self.logger.warning(msg)
                    if not silent_mode:
                        print(msg)
                    # NOTE(review): counted regardless of silent_mode here —
                    # the diff lost original indentation; confirm whether
                    # increment() belongs inside the silent_mode guard.
                    self.warning_counter.increment()
+ def __drop_ignore_columns(self):
403
+ """Drop ignore columns"""
404
+ columns_to_drop = list(set(self.data.columns) & set(self.ignore_columns))
405
+ if len(columns_to_drop) > 0:
406
+ # self.logger.info(f"Dropping ignore columns: {self.ignore_columns}")
407
+ self.data.drop(columns_to_drop, axis=1, inplace=True)
408
+
163
409
  def __target_value(self) -> pd.Series:
164
410
  target_column = self.etalon_def_checked.get(FileColumnMeaningType.TARGET.value, "")
165
411
  target: pd.Series = self.data[target_column]
@@ -280,6 +526,52 @@ class Dataset: # (pd.DataFrame):
280
526
  self.data = resampled_data
281
527
  self.logger.info(f"Shape after threshold resampling: {self.data.shape}")
282
528
 
529
+ def __convert_phone(self):
530
+ """Convert phone/msisdn to int"""
531
+ # self.logger.info("Convert phone to int")
532
+ msisdn_column = self.etalon_def_checked.get(FileColumnMeaningType.MSISDN.value)
533
+ country_column = self.etalon_def_checked.get(FileColumnMeaningType.COUNTRY.value)
534
+ if msisdn_column is not None and msisdn_column in self.data.columns:
535
+ normalizer = PhoneNormalizer(self.data, msisdn_column, country_column)
536
+ self.data[msisdn_column] = normalizer.normalize()
537
+ if self.data[msisdn_column].isnull().all():
538
+ raise ValidationError(f"All values of PHONE column `{msisdn_column}` are invalid")
539
+
540
+ def __features(self):
541
+ return [
542
+ f for f, meaning_type in self.meaning_types_checked.items() if meaning_type == FileColumnMeaningType.FEATURE
543
+ ]
544
+
545
    def __remove_dates_from_features(self, silent_mode: bool = False):
        """Drop datetime/period-typed feature columns from data and meaning types.

        Emits a single warning listing everything removed; printed unless
        silent_mode, and counted in warning_counter.
        """
        # self.logger.info("Remove date columns from features")

        removed_features = []
        for f in self.__features():
            if is_datetime(self.data[f]) or is_period_dtype(self.data[f]):
                removed_features.append(f)
                self.data.drop(columns=f, inplace=True)
                # Relies on meaning_types_checked exposing the live mapping.
                del self.meaning_types_checked[f]

        if removed_features:
            msg = self.bundle.get("dataset_date_features").format(removed_features)
            self.logger.warning(msg)
            if not silent_mode:
                print(msg)
            # NOTE(review): indentation reconstructed from a diff — confirm
            # whether increment() sits inside the silent_mode guard.
            self.warning_counter.increment()
+ def __validate_features_count(self):
563
+ if len(self.__features()) > self.MAX_FEATURES_COUNT:
564
+ msg = self.bundle.get("dataset_too_many_features").format(self.MAX_FEATURES_COUNT)
565
+ self.logger.warning(msg)
566
+ raise ValidationError(msg)
567
+
568
+ def __convert_features_types(self):
569
+ # self.logger.info("Convert features to supported data types")
570
+
571
+ for f in self.__features():
572
+ if not is_numeric_dtype(self.data[f]):
573
+ self.data[f] = self.data[f].astype("string")
574
+
283
575
  def __validate_dataset(self, validate_target: bool, silent_mode: bool):
284
576
  """Validate DataSet"""
285
577
  # self.logger.info("validating etalon")
@@ -302,7 +594,7 @@ class Dataset: # (pd.DataFrame):
302
594
  key
303
595
  for search_group in self.search_keys_checked
304
596
  for key in search_group
305
- if not self.columns_renaming.get(key).endswith(EmailSearchKeyConverter.ONE_DOMAIN_SUFFIX)
597
+ if self.columns_renaming.get(key) != EmailSearchKeyConverter.EMAIL_ONE_DOMAIN_COLUMN_NAME
306
598
  }
307
599
  ipv4_column = self.etalon_def_checked.get(FileColumnMeaningType.IP_ADDRESS.value)
308
600
  if (
@@ -410,7 +702,69 @@ class Dataset: # (pd.DataFrame):
410
702
  if len(self.data) == 0:
411
703
  raise ValidationError(self.bundle.get("all_search_keys_invalid"))
412
704
 
705
+ def __validate_meaning_types(self, validate_target: bool):
706
+ # self.logger.info("Validating meaning types")
707
+ if self.meaning_types is None or len(self.meaning_types) == 0:
708
+ raise ValueError(self.bundle.get("dataset_missing_meaning_types"))
709
+
710
+ if SYSTEM_RECORD_ID not in self.data.columns:
711
+ raise ValueError("Internal error")
712
+
713
+ for column in self.meaning_types:
714
+ if column not in self.data.columns:
715
+ raise ValueError(self.bundle.get("dataset_missing_meaning_column").format(column, self.data.columns))
716
+ if validate_target and FileColumnMeaningType.TARGET not in self.meaning_types.values():
717
+ raise ValueError(self.bundle.get("dataset_missing_target"))
718
+
719
+ def __validate_search_keys(self):
720
+ # self.logger.info("Validating search keys")
721
+ if self.search_keys is None or len(self.search_keys) == 0:
722
+ raise ValueError(self.bundle.get("dataset_missing_search_keys"))
723
+ for keys_group in self.search_keys:
724
+ for key in keys_group:
725
+ if key not in self.data.columns:
726
+ showing_columns = set(self.data.columns) - SYSTEM_COLUMNS
727
+ raise ValidationError(
728
+ self.bundle.get("dataset_missing_search_key_column").format(key, showing_columns)
729
+ )
730
+
413
731
  def validate(self, validate_target: bool = True, silent_mode: bool = False):
732
+ # self.logger.info("Validating dataset")
733
+
734
+ self.__validate_search_keys()
735
+
736
+ self.__validate_meaning_types(validate_target=validate_target)
737
+
738
+ self.__drop_ignore_columns()
739
+
740
+ self.__rename_columns()
741
+
742
+ self.__remove_dates_from_features(silent_mode)
743
+
744
+ self.__validate_features_count()
745
+
746
+ self.__validate_too_long_string_values()
747
+
748
+ self.__convert_bools()
749
+
750
+ self.__convert_float16()
751
+
752
+ self.__correct_decimal_comma()
753
+
754
+ self.__remove_old_dates(silent_mode)
755
+
756
+ self.__convert_ip()
757
+
758
+ self.__convert_phone()
759
+
760
+ self.__normalize_iso_code()
761
+
762
+ self.__normalize_postal_code()
763
+
764
+ self.__normalize_hem()
765
+
766
+ self.__convert_features_types()
767
+
414
768
  self.__validate_dataset(validate_target, silent_mode)
415
769
 
416
770
  if validate_target:
@@ -428,39 +782,38 @@ class Dataset: # (pd.DataFrame):
428
782
  # self.logger.info("Constructing dataset metadata")
429
783
  columns = []
430
784
  for index, (column_name, column_type) in enumerate(zip(self.data.columns, self.data.dtypes)):
431
- if column_name in self.meaning_types_checked:
432
- meaning_type = self.meaning_types_checked[column_name]
433
- # Temporary workaround while backend doesn't support datetime
434
- if meaning_type == FileColumnMeaningType.DATETIME:
435
- meaning_type = FileColumnMeaningType.DATE
436
- else:
437
- meaning_type = FileColumnMeaningType.FEATURE
438
- if meaning_type in {
439
- FileColumnMeaningType.DATE,
440
- FileColumnMeaningType.DATETIME,
441
- # FileColumnMeaningType.IP_ADDRESS,
442
- }:
443
- min_value = self.data[column_name].astype("Int64").min()
444
- max_value = self.data[column_name].astype("Int64").max()
445
- min_max_values = NumericInterval(
446
- minValue=min_value,
447
- maxValue=max_value,
785
+ if column_name not in self.ignore_columns:
786
+ if column_name in self.meaning_types_checked:
787
+ meaning_type = self.meaning_types_checked[column_name]
788
+ # Temporary workaround while backend doesn't support datetime
789
+ if meaning_type == FileColumnMeaningType.DATETIME:
790
+ meaning_type = FileColumnMeaningType.DATE
791
+ else:
792
+ meaning_type = FileColumnMeaningType.FEATURE
793
+ if meaning_type in {
794
+ FileColumnMeaningType.DATE,
795
+ FileColumnMeaningType.DATETIME,
796
+ # FileColumnMeaningType.IP_ADDRESS,
797
+ }:
798
+ min_max_values = NumericInterval(
799
+ minValue=self.data[column_name].astype("Int64").min(),
800
+ maxValue=self.data[column_name].astype("Int64").max(),
801
+ )
802
+ else:
803
+ min_max_values = None
804
+ column_meta = FileColumnMetadata(
805
+ index=index,
806
+ name=column_name,
807
+ originalName=self.columns_renaming.get(column_name) or column_name,
808
+ dataType=self.__get_data_type(column_type, column_name),
809
+ meaningType=meaning_type,
810
+ minMaxValues=min_max_values,
448
811
  )
449
- else:
450
- min_max_values = None
451
- column_meta = FileColumnMetadata(
452
- index=index,
453
- name=column_name,
454
- originalName=self.columns_renaming.get(column_name) or column_name,
455
- dataType=self.__get_data_type(column_type, column_name),
456
- meaningType=meaning_type,
457
- minMaxValues=min_max_values,
458
- )
459
- if self.unnest_search_keys and column_meta.originalName in self.unnest_search_keys:
460
- column_meta.isUnnest = True
461
- column_meta.unnestKeyNames = self.unnest_search_keys[column_meta.originalName]
812
+ if self.unnest_search_keys and column_meta.originalName in self.unnest_search_keys:
813
+ column_meta.isUnnest = True
814
+ column_meta.unnestKeyNames = self.unnest_search_keys[column_meta.originalName]
462
815
 
463
- columns.append(column_meta)
816
+ columns.append(column_meta)
464
817
 
465
818
  return FileMetadata(
466
819
  name=self.dataset_name,