upgini 1.1.315__py3-none-any.whl → 1.1.315a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of upgini might be problematic; see the package's registry page for more details.

upgini/metadata.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from enum import Enum
4
- from typing import Dict, List, Optional, Set
4
+ from typing import Dict, List, Optional, Set, Union
5
5
 
6
6
  from pydantic import BaseModel
7
7
 
@@ -113,6 +113,21 @@ class SearchKey(Enum):
113
113
  if meaning_type == FileColumnMeaningType.MSISDN_RANGE_TO:
114
114
  return SearchKey.MSISDN_RANGE_TO
115
115
 
116
@staticmethod
def find_key(search_keys: Dict[str, SearchKey], keys: Union[SearchKey, List[SearchKey]]) -> Optional[str]:
    """Return the first column whose search-key type is one of ``keys``.

    :param search_keys: mapping of column name -> SearchKey type.
    :param keys: a single SearchKey or a list of acceptable SearchKey types.
    :return: the first matching column name in ``search_keys`` iteration order,
        or ``None`` when no column matches.
    """
    if isinstance(keys, SearchKey):
        keys = [keys]
    # next() with a default keeps the "first match, else None" contract
    return next((col for col, key_type in search_keys.items() if key_type in keys), None)
124
+
125
@staticmethod
def find_all_keys(search_keys: Dict[str, SearchKey], keys: Union[SearchKey, List[SearchKey]]) -> List[str]:
    """Return every column whose search-key type is one of ``keys``.

    :param search_keys: mapping of column name -> SearchKey type.
    :param keys: a single SearchKey or a list of acceptable SearchKey types.
    :return: matching column names in ``search_keys`` iteration order (possibly empty).
    """
    if isinstance(keys, SearchKey):
        keys = [keys]
    return [col for col, key_type in search_keys.items() if key_type in keys]
130
+
116
131
 
117
132
  class DataType(Enum):
118
133
  INT = "INT"
@@ -0,0 +1,203 @@
1
+ import hashlib
2
+ from logging import Logger, getLogger
3
+ from typing import Dict, List
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from pandas.api.types import is_bool_dtype as is_bool
8
+ from pandas.api.types import is_datetime64_any_dtype as is_datetime
9
+ from pandas.api.types import (
10
+ is_float_dtype,
11
+ is_numeric_dtype,
12
+ is_object_dtype,
13
+ is_period_dtype,
14
+ is_string_dtype,
15
+ )
16
+
17
+ from upgini.errors import ValidationError
18
+ from upgini.metadata import (
19
+ ENTITY_SYSTEM_RECORD_ID,
20
+ EVAL_SET_INDEX,
21
+ SEARCH_KEY_UNNEST,
22
+ SYSTEM_RECORD_ID,
23
+ TARGET,
24
+ SearchKey,
25
+ )
26
+ from upgini.resource_bundle import ResourceBundle, get_custom_bundle
27
+ from upgini.utils import find_numbers_with_decimal_comma
28
+ from upgini.utils.datetime_utils import DateTimeSearchKeyConverter
29
+ from upgini.utils.phone_utils import PhoneSearchKeyConverter
30
+ from upgini.utils.warning_counter import WarningCounter
31
+
32
+
33
class Normalizer:
    """Prepare a user dataframe for the search pipeline.

    ``normalize`` works on a copy of the input and applies, in this order:
    1. renaming columns to a ``[a-z0-9_]`` alphabet;
    2. removing datetime/period features;
    3. truncating long string values;
    4. converting bool columns to str;
    5. casting floats to float64;
    6. repairing decimal commas;
    7. converting phone search keys;
    8. casting the remaining non-numeric features to ``string``.

    Side effects: the ``search_keys`` dict passed in is updated in place when a search-key
    column is renamed, and ``columns_renaming`` accumulates ``new name -> original name``.
    """

    # Longest string value (in characters) kept in a feature; longer values are cut.
    MAX_STRING_FEATURE_LENGTH = 24573

    def __init__(
        self,
        search_keys: Dict[str, SearchKey],
        generated_features: List[str],
        bundle: ResourceBundle = None,
        logger: Logger = None,
        warnings_counter: WarningCounter = None,
        silent_mode=False,
    ):
        self.search_keys = search_keys
        self.generated_features = generated_features
        self.bundle = bundle or get_custom_bundle()
        self.logger = logger or getLogger()
        self.warnings_counter = warnings_counter or WarningCounter()
        self.silent_mode = silent_mode
        # new column name -> original column name, filled by _rename_columns
        self.columns_renaming = {}

    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a normalized copy of ``df``; the input frame itself is not modified."""
        df = df.copy()
        df = self._rename_columns(df)

        df = self._remove_dates_from_features(df)

        df = self._cut_too_long_string_values(df)

        df = self._convert_bools(df)

        df = self._convert_float16(df)

        df = self._correct_decimal_comma(df)

        df = self._convert_phone_numbers(df)

        df = self.__convert_features_types(df)

        return df

    def _rename_columns(self, df: pd.DataFrame):
        """Rename user columns to unique, DB-safe names.

        System columns and ``generated_features`` keep their names. Every other column is
        processed as follows:
        - converted with ``str`` and cut to 250 characters;
        - suffixed with ``_`` plus the first 6 hex digits of the SHA-256 of its name;
        - lower-cased;
        - prefixed with ``"a"`` unless it starts with ``a-z``;
        - every character outside ``[a-z0-9]`` replaced by ``"_"``.
        A name that clashes with an earlier one gets a numeric ``_<n>`` suffix.

        :raises ValidationError: if a column name is the empty string.
        """
        import re  # local import keeps the module-level import block unchanged

        # columns that are passed through untouched (built once, not per column)
        reserved_columns = [
            TARGET,
            EVAL_SET_INDEX,
            SYSTEM_RECORD_ID,
            ENTITY_SYSTEM_RECORD_ID,
            SEARCH_KEY_UNNEST,
            DateTimeSearchKeyConverter.DATETIME_COL,
        ] + self.generated_features
        new_columns = []
        # set mirror of new_columns for O(1) clash checks
        used_names = set()
        dup_counter = 0
        for column in df.columns:
            if column in reserved_columns:
                self.columns_renaming[column] = column
                new_columns.append(column)
                used_names.add(column)
                continue

            new_column = str(column)
            suffix = hashlib.sha256(new_column.encode()).hexdigest()[:6]
            if len(new_column) == 0:
                raise ValidationError(self.bundle.get("dataset_empty_column_names"))
            # db limit for column length
            # NOTE(review): the "_<hash>" suffix, an optional "a" prefix and a dedup counter are
            # appended after this cut, so final names can exceed 250 chars - confirm the limit.
            if len(new_column) > 250:
                new_column = new_column[:250]

            # make column name unique relative to server features
            new_column = f"{new_column}_{suffix}".lower()

            # if column starts with non alphabetic symbol then add "a" to the beginning of string
            if not ("a" <= new_column[0] <= "z"):
                new_column = "a" + new_column

            # replace unsupported characters with "_" in a single pass
            new_column = re.sub(r"[^a-z0-9]", "_", new_column)

            if new_column in used_names:
                # retry until the candidate is free so the resulting names are always unique
                candidate = f"{new_column}_{dup_counter}"
                dup_counter += 1
                while candidate in used_names:
                    candidate = f"{new_column}_{dup_counter}"
                    dup_counter += 1
                new_column = candidate
            new_columns.append(new_column)
            used_names.add(new_column)

            if new_column != column and column in self.search_keys:
                # move the search-key type to the new name (old key is removed)
                self.search_keys[new_column] = self.search_keys.pop(column)
            self.columns_renaming[new_column] = str(column)
        df.columns = new_columns
        return df

    def _get_features(self, df: pd.DataFrame) -> List[str]:
        """Return the sorted names of columns that are neither search keys nor system columns."""
        system_columns = [ENTITY_SYSTEM_RECORD_ID, EVAL_SET_INDEX, SEARCH_KEY_UNNEST, SYSTEM_RECORD_ID, TARGET]
        features = set(df.columns) - set(self.search_keys.keys()) - set(system_columns)
        return sorted(list(features))

    def _remove_dates_from_features(self, df: pd.DataFrame):
        """Drop feature columns with datetime or period dtype and warn about them."""
        features = self._get_features(df)

        removed_features = []
        for f in features:
            if is_datetime(df[f]) or is_period_dtype(df[f]):
                removed_features.append(f)
                df.drop(columns=f, inplace=True)

        if removed_features:
            msg = self.bundle.get("dataset_date_features").format(removed_features)
            self.logger.warning(msg)
            if not self.silent_mode:
                print(msg)
                self.warnings_counter.increment()

        return df

    def _cut_too_long_string_values(self, df: pd.DataFrame):
        """Truncate string/object values longer than ``MAX_STRING_FEATURE_LENGTH`` characters.

        Only columns whose longest ``str``-cast value exceeds the limit are rewritten; they are
        replaced by their ``str``-cast values cut to the limit.
        """
        for col in df.columns:
            if is_string_dtype(df[col]) or is_object_dtype(df[col]):
                # cast once and reuse it for both the length check and the truncation
                as_str = df[col].astype("str")
                max_length: int = as_str.str.len().max()
                if max_length > self.MAX_STRING_FEATURE_LENGTH:
                    df[col] = as_str.str.slice(stop=self.MAX_STRING_FEATURE_LENGTH)

        return df

    @staticmethod
    def _convert_bools(df: pd.DataFrame):
        """Convert bool columns to their string representation."""
        for col in df.columns:
            if is_bool(df[col]):
                df[col] = df[col].astype("str")
        return df

    @staticmethod
    def _convert_float16(df: pd.DataFrame):
        """Cast every float column (float16, float32, ...) to float64."""
        for col in df.columns:
            if is_float_dtype(df[col]):
                df[col] = df[col].astype("float64")
        return df

    def _correct_decimal_comma(self, df: pd.DataFrame):
        """Convert string columns that hold decimal-comma numbers (e.g. ``"1,5"``) to float64."""
        columns_to_fix = find_numbers_with_decimal_comma(df)
        if len(columns_to_fix) > 0:
            self.logger.warning(f"Convert strings with decimal comma to float: {columns_to_fix}")
            for col in columns_to_fix:
                df[col] = df[col].astype("string").str.replace(",", ".", regex=False).astype(np.float64)
        return df

    def _convert_phone_numbers(self, df: pd.DataFrame) -> pd.DataFrame:
        """Run ``PhoneSearchKeyConverter`` on every PHONE search-key column.

        The COUNTRY search-key column (or ``None`` if there is none) is passed along.
        """
        maybe_country_col = SearchKey.find_key(self.search_keys, SearchKey.COUNTRY)
        for phone_col in SearchKey.find_all_keys(self.search_keys, SearchKey.PHONE):
            converter = PhoneSearchKeyConverter(phone_col, maybe_country_col)
            df = converter.convert(df)
        return df

    def __convert_features_types(self, df: pd.DataFrame):
        """Cast every non-numeric feature column to the pandas ``string`` dtype."""
        for f in self._get_features(df):
            if not is_numeric_dtype(df[f]):
                df[f] = df[f].astype("string")
        return df
@@ -4,6 +4,22 @@ from pandas.api.types import is_object_dtype, is_string_dtype
4
4
  from upgini.utils.base_search_key_detector import BaseSearchKeyDetector
5
5
 
6
6
 
7
class CountrySearchKeyConverter:
    """Normalize a country search-key column to upper-case letter codes.

    Values are cast to the pandas ``string`` dtype, upper-cased and stripped of every
    character outside ``A-Z``; a value that is exactly ``UK`` is then mapped to ``GB``.
    """

    def __init__(self, country_col: str):
        self.country_col = country_col

    def convert(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalize ``df[self.country_col]`` in place and return ``df``.

        Only the whole value ``UK`` is rewritten to ``GB``; a substring replacement would
        corrupt values such as ``UKRAINE`` (-> ``GBRAINE``).
        """
        df[self.country_col] = (
            df[self.country_col]
            .astype("string")
            .str.upper()
            .str.replace(r"[^A-Z]", "", regex=True)
            # anchored match: replace the alias only when it is the entire value
            .str.replace(r"^UK$", "GB", regex=True)
        )
        return df
21
+
22
+
7
23
  class CountrySearchKeyDetector(BaseSearchKeyDetector):
8
24
  def _is_search_key_by_name(self, column_name: str) -> bool:
9
25
  return "country" in str(column_name).lower()
@@ -1,18 +1,16 @@
1
1
  import datetime
2
2
  import logging
3
3
  import re
4
+ import pytz
4
5
  from typing import Dict, List, Optional
5
6
 
6
7
  import numpy as np
7
8
  import pandas as pd
8
9
  from dateutil.relativedelta import relativedelta
9
- from pandas.api.types import (
10
- is_numeric_dtype,
11
- is_period_dtype,
12
- )
10
+ from pandas.api.types import is_numeric_dtype, is_period_dtype
13
11
 
14
12
  from upgini.errors import ValidationError
15
- from upgini.metadata import SearchKey
13
+ from upgini.metadata import EVAL_SET_INDEX, SearchKey
16
14
  from upgini.resource_bundle import ResourceBundle, get_custom_bundle
17
15
  from upgini.utils.warning_counter import WarningCounter
18
16
 
@@ -31,18 +29,22 @@ DATE_FORMATS = [
31
29
  "%Y-%m-%dT%H:%M:%S.%f",
32
30
  ]
33
31
 
34
- DATETIME_PATTERN = r"^[\d\s\.\-:T/]+$"
32
+ DATETIME_PATTERN = r"^[\d\s\.\-:T/+]+$"
35
33
 
36
34
 
37
35
  class DateTimeSearchKeyConverter:
38
36
  DATETIME_COL = "_date_time"
37
+ # MIN_SUPPORTED_DATE_TS = datetime.datetime(1999, 12, 31) # 946684800000 # 2000-01-01
38
+ MIN_SUPPORTED_DATE_TS = pd.to_datetime(datetime.datetime(1999, 12, 31)).tz_localize(None)
39
39
 
40
40
  def __init__(
41
41
  self,
42
42
  date_column: str,
43
43
  date_format: Optional[str] = None,
44
44
  logger: Optional[logging.Logger] = None,
45
- bundle: ResourceBundle = None,
45
+ bundle: Optional[ResourceBundle] = None,
46
+ warnings_counter: Optional[WarningCounter] = None,
47
+ silent_mode=False,
46
48
  ):
47
49
  self.date_column = date_column
48
50
  self.date_format = date_format
@@ -53,6 +55,8 @@ class DateTimeSearchKeyConverter:
53
55
  self.logger.setLevel("FATAL")
54
56
  self.generated_features: List[str] = []
55
57
  self.bundle = bundle or get_custom_bundle()
58
+ self.warnings_counter = warnings_counter or WarningCounter()
59
+ self.silent_mode = silent_mode
56
60
 
57
61
  @staticmethod
58
62
  def _int_to_opt(i: int) -> Optional[int]:
@@ -88,13 +92,13 @@ class DateTimeSearchKeyConverter:
88
92
  # 315532801000 - 2524608001000 - milliseconds
89
93
  # 315532801000000 - 2524608001000000 - microseconds
90
94
  # 315532801000000000 - 2524608001000000000 - nanoseconds
91
- if df[self.date_column].apply(lambda x: 10 ** 16 < x).all():
95
+ if df[self.date_column].apply(lambda x: 10**16 < x).all():
92
96
  df[self.date_column] = pd.to_datetime(df[self.date_column], unit="ns")
93
- elif df[self.date_column].apply(lambda x: 10 ** 14 < x < 10 ** 16).all():
97
+ elif df[self.date_column].apply(lambda x: 10**14 < x < 10**16).all():
94
98
  df[self.date_column] = pd.to_datetime(df[self.date_column], unit="us")
95
- elif df[self.date_column].apply(lambda x: 10 ** 11 < x < 10 ** 14).all():
99
+ elif df[self.date_column].apply(lambda x: 10**11 < x < 10**14).all():
96
100
  df[self.date_column] = pd.to_datetime(df[self.date_column], unit="ms")
97
- elif df[self.date_column].apply(lambda x: 0 < x < 10 ** 11).all():
101
+ elif df[self.date_column].apply(lambda x: 0 < x < 10**11).all():
98
102
  df[self.date_column] = pd.to_datetime(df[self.date_column], unit="s")
99
103
  else:
100
104
  msg = self.bundle.get("unsupported_date_type").format(self.date_column)
@@ -108,6 +112,9 @@ class DateTimeSearchKeyConverter:
108
112
  # as additional features
109
113
  seconds = "datetime_seconds"
110
114
  df[self.date_column] = df[self.date_column].dt.tz_localize(None)
115
+
116
+ df = self.clean_old_dates(df)
117
+
111
118
  df[seconds] = (df[self.date_column] - df[self.date_column].dt.floor("D")).dt.seconds
112
119
 
113
120
  seconds_without_na = df[seconds].dropna()
@@ -152,6 +159,19 @@ class DateTimeSearchKeyConverter:
152
159
  except ValueError:
153
160
  raise ValidationError(self.bundle.get("invalid_date_format").format(self.date_column))
154
161
 
162
def clean_old_dates(self, df: pd.DataFrame) -> pd.DataFrame:
    """Blank out dates that are not later than ``MIN_SUPPORTED_DATE_TS``.

    Matching cells of ``self.date_column`` are set to ``None`` in place. When any are found,
    the method logs an info line with the row count and the ``dataset_drop_old_dates``
    warning, and prints the warning unless ``silent_mode`` is set.

    :return: the same ``df`` object.
    """
    condition = df[self.date_column] <= self.MIN_SUPPORTED_DATE_TS
    # count matches straight from the boolean mask instead of copying the old rows
    old_count = int(condition.sum())
    if old_count > 0:
        df.loc[condition, self.date_column] = None
        # NOTE(review): the cutoff is 1999-12-31 00:00 compared with "<=", so timestamps later on
        # 1999-12-31 are kept although the message says "before 2000-01-01" - confirm intent.
        self.logger.info(f"Set to None: {old_count} of {len(df)} rows because they are before 2000-01-01")
        msg = self.bundle.get("dataset_drop_old_dates")
        self.logger.warning(msg)
        if not self.silent_mode:
            print(msg)
            self.warnings_counter.increment()
    return df
174
+
155
175
 
156
176
  def is_time_series(df: pd.DataFrame, date_col: str) -> bool:
157
177
  try:
@@ -238,16 +258,18 @@ def is_blocked_time_series(df: pd.DataFrame, date_col: str, search_keys: List[st
238
258
 
239
259
 
240
260
  def validate_dates_distribution(
241
- X: pd.DataFrame,
261
+ df: pd.DataFrame,
242
262
  search_keys: Dict[str, SearchKey],
243
263
  logger: Optional[logging.Logger] = None,
244
264
  bundle: Optional[ResourceBundle] = None,
245
265
  warning_counter: Optional[WarningCounter] = None,
246
266
  ):
247
- maybe_date_col = None
248
- for key, key_type in search_keys.items():
249
- if key_type in [SearchKey.DATE, SearchKey.DATETIME]:
250
- maybe_date_col = key
267
+ maybe_date_col = SearchKey.find_key(search_keys, [SearchKey.DATE, SearchKey.DATETIME])
268
+
269
+ if EVAL_SET_INDEX in df.columns:
270
+ X = df.query(f"{EVAL_SET_INDEX} == 0")
271
+ else:
272
+ X = df
251
273
 
252
274
  if maybe_date_col is None:
253
275
  for col in X.columns:
@@ -7,7 +7,7 @@ import pandas as pd
7
7
  from pandas.api.types import is_object_dtype, is_string_dtype
8
8
 
9
9
  from upgini.metadata import SearchKey
10
- from upgini.resource_bundle import bundle
10
+ from upgini.resource_bundle import ResourceBundle, get_custom_bundle
11
11
  from upgini.utils.base_search_key_detector import BaseSearchKeyDetector
12
12
 
13
13
  EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9.!#$%&’*+/=?^_`{|}~-]+@[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*$")
@@ -28,29 +28,53 @@ class EmailSearchKeyDetector(BaseSearchKeyDetector):
28
28
  return is_email_count / all_count > 0.1
29
29
 
30
30
 
31
class EmailDomainGenerator:
    """Add a ``<email column>_domain`` feature holding the part of each email after ``@``."""

    DOMAIN_SUFFIX = "_domain"

    def __init__(self, email_columns: List[str]):
        self.email_columns = email_columns
        # names of the columns added by generate(), each listed once, in creation order
        self.generated_features: List[str] = []

    def generate(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add one domain column per email column to ``df`` (in place) and return ``df``.

        Calling ``generate`` again (e.g. on another frame) recomputes the columns but does not
        register the same feature name twice in ``generated_features``.
        """
        for email_col in self.email_columns:
            domain_feature = email_col + self.DOMAIN_SUFFIX
            df[domain_feature] = df[email_col].apply(self._email_to_domain)
            if domain_feature not in self.generated_features:
                self.generated_features.append(domain_feature)
        return df

    @staticmethod
    def _email_to_domain(email: str) -> Optional[str]:
        """Return the text after ``@`` in ``email``.

        Returns ``None`` unless ``email`` is a str with exactly one ``@`` followed by a
        non-empty domain part.
        """
        if isinstance(email, str) and "@" in email:
            name_and_domain = email.split("@")
            if len(name_and_domain) == 2 and len(name_and_domain[1]) > 0:
                return name_and_domain[1]
        return None
51
+
52
+
31
53
  class EmailSearchKeyConverter:
32
- HEM_COLUMN_NAME = "hashed_email"
33
- DOMAIN_COLUMN_NAME = "email_domain"
34
- EMAIL_ONE_DOMAIN_COLUMN_NAME = "email_one_domain"
54
+ HEM_SUFFIX = "_hem"
55
+ ONE_DOMAIN_SUFFIX = "_one_domain"
35
56
 
36
57
  def __init__(
37
58
  self,
38
59
  email_column: str,
39
60
  hem_column: Optional[str],
40
61
  search_keys: Dict[str, SearchKey],
62
+ columns_renaming: Dict[str, str],
41
63
  unnest_search_keys: Optional[List[str]] = None,
64
+ bundle: Optional[ResourceBundle] = None,
42
65
  logger: Optional[logging.Logger] = None,
43
66
  ):
44
67
  self.email_column = email_column
45
68
  self.hem_column = hem_column
46
69
  self.search_keys = search_keys
70
+ self.columns_renaming = columns_renaming
47
71
  self.unnest_search_keys = unnest_search_keys
72
+ self.bundle = bundle or get_custom_bundle()
48
73
  if logger is not None:
49
74
  self.logger = logger
50
75
  else:
51
76
  self.logger = logging.getLogger()
52
77
  self.logger.setLevel("FATAL")
53
- self.generated_features: List[str] = []
54
78
  self.email_converted_to_hem = False
55
79
 
56
80
  @staticmethod
@@ -61,7 +85,7 @@ class EmailSearchKeyConverter:
61
85
  if not EMAIL_REGEX.fullmatch(email):
62
86
  return None
63
87
 
64
- return sha256(email.lower().encode("utf-8")).hexdigest()
88
+ return sha256(email.lower().encode("utf-8")).hexdigest().lower()
65
89
 
66
90
  @staticmethod
67
91
  def _email_to_one_domain(email: str) -> Optional[str]:
@@ -72,28 +96,36 @@ class EmailSearchKeyConverter:
72
96
 
73
97
  def convert(self, df: pd.DataFrame) -> pd.DataFrame:
74
98
  df = df.copy()
99
+ original_email_column = self.columns_renaming[self.email_column]
75
100
  if self.hem_column is None:
76
- df[self.HEM_COLUMN_NAME] = df[self.email_column].apply(self._email_to_hem)
77
- if df[self.HEM_COLUMN_NAME].isna().all():
78
- msg = bundle.get("all_emails_invalid").format(self.email_column)
101
+ hem_name = self.email_column + self.HEM_SUFFIX
102
+ df[hem_name] = df[self.email_column].apply(self._email_to_hem)
103
+ if df[hem_name].isna().all():
104
+ msg = self.bundle.get("all_emails_invalid").format(self.email_column)
79
105
  print(msg)
80
106
  self.logger.warning(msg)
81
- df = df.drop(columns=self.HEM_COLUMN_NAME)
107
+ df = df.drop(columns=hem_name)
82
108
  del self.search_keys[self.email_column]
83
109
  return df
84
- self.search_keys[self.HEM_COLUMN_NAME] = SearchKey.HEM
85
- self.unnest_search_keys.append(self.HEM_COLUMN_NAME)
110
+ self.search_keys[hem_name] = SearchKey.HEM
111
+ if self.email_column in self.unnest_search_keys:
112
+ self.unnest_search_keys.append(hem_name)
113
+ self.columns_renaming[hem_name] = original_email_column # it could be upgini_email_unnest...
86
114
  self.email_converted_to_hem = True
115
+ else:
116
+ df[self.hem_column] = df[self.hem_column].astype("string").str.lower()
87
117
 
88
118
  del self.search_keys[self.email_column]
89
119
  if self.email_column in self.unnest_search_keys:
90
120
  self.unnest_search_keys.remove(self.email_column)
91
121
 
92
- df[self.EMAIL_ONE_DOMAIN_COLUMN_NAME] = df[self.email_column].apply(self._email_to_one_domain)
93
-
94
- self.search_keys[self.EMAIL_ONE_DOMAIN_COLUMN_NAME] = SearchKey.EMAIL_ONE_DOMAIN
122
+ one_domain_name = self.email_column + self.ONE_DOMAIN_SUFFIX
123
+ df[one_domain_name] = df[self.email_column].apply(self._email_to_one_domain)
124
+ self.columns_renaming[one_domain_name] = original_email_column
125
+ self.search_keys[one_domain_name] = SearchKey.EMAIL_ONE_DOMAIN
95
126
 
96
- df[self.DOMAIN_COLUMN_NAME] = df[self.EMAIL_ONE_DOMAIN_COLUMN_NAME].str[1:]
97
- self.generated_features.append(self.DOMAIN_COLUMN_NAME)
127
+ if self.email_converted_to_hem:
128
+ df = df.drop(columns=self.email_column)
129
+ del self.columns_renaming[self.email_column]
98
130
 
99
131
  return df
upgini/utils/ip_utils.py CHANGED
@@ -1,15 +1,114 @@
1
1
  import logging
2
- from typing import Dict, List, Optional
2
+ from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address
3
+ from typing import Dict, List, Optional, Union
3
4
 
4
5
  import pandas as pd
5
6
  from requests import get
6
7
 
8
+ from upgini.errors import ValidationError
7
9
  from upgini.metadata import SearchKey
10
+ from upgini.resource_bundle import ResourceBundle, get_custom_bundle
8
11
 
9
12
  # from upgini.resource_bundle import bundle
10
13
  # from upgini.utils.track_info import get_track_metrics
11
14
 
12
15
 
16
class IpSearchKeyConverter:
    """Replace an IP-address search-key column with integer encodings of the address.

    ``convert`` parses ``ip_column`` with :func:`ipaddress.ip_address` and replaces it with
    two columns:
    - ``<ip_column>_v4``: nullable ``Int64``, registered as ``SearchKey.IP``;
    - ``<ip_column>_v6``: decimal string of the IPv6 form, registered as
      ``SearchKey.IPV6_ADDRESS``.
    ``search_keys`` and ``columns_renaming`` are updated in place.
    """

    def __init__(
        self,
        ip_column: str,
        search_keys: Dict[str, SearchKey],
        columns_renaming: Dict[str, str],
        unnest_search_keys: Optional[List[str]] = None,
        bundle: Optional[ResourceBundle] = None,
        logger: Optional[logging.Logger] = None,
    ):
        self.ip_column = ip_column
        self.search_keys = search_keys
        self.columns_renaming = columns_renaming
        # NOTE(review): stored but never read or updated by convert(); if ip_column is listed
        # here it stays listed after the column is dropped - confirm callers don't rely on it.
        self.unnest_search_keys = unnest_search_keys
        self.bundle = bundle or get_custom_bundle()
        if logger is not None:
            self.logger = logger
        else:
            self.logger = logging.getLogger()
            self.logger.setLevel("FATAL")

    @staticmethod
    def _ip_to_int(ip: Optional[_BaseAddress]) -> Optional[int]:
        """Return the integer value of an IPv4/IPv6 address, or ``None`` for anything else."""
        # int() on an ipaddress object cannot fail, so no exception handling is needed
        if isinstance(ip, (IPv4Address, IPv6Address)):
            return int(ip)
        return None

    @staticmethod
    def _ip_to_int_str(ip: Optional[_BaseAddress]) -> Optional[str]:
        """Return the decimal string of an IPv4/IPv6 address, or ``None`` for anything else."""
        if isinstance(ip, (IPv4Address, IPv6Address)):
            return str(int(ip))
        return None

    @staticmethod
    def _safe_ip_parse(ip: Union[str, int, IPv4Address, IPv6Address]) -> Optional[_BaseAddress]:
        """Parse ``ip`` with ``ipaddress.ip_address``; return ``None`` if it is not a valid address."""
        try:
            return ip_address(ip)
        except ValueError:
            return None

    @staticmethod
    def _is_ipv4(ip: Optional[_BaseAddress]):
        """True for an IPv4 address or an IPv4-mapped IPv6 address (``::ffff:a.b.c.d``)."""
        return ip is not None and (
            isinstance(ip, IPv4Address) or (isinstance(ip, IPv6Address) and ip.ipv4_mapped is not None)
        )

    @staticmethod
    def _to_ipv4(ip: Optional[_BaseAddress]) -> Optional[IPv4Address]:
        """Return the IPv4 form of ``ip`` (unwrapping IPv4-mapped IPv6), else ``None``."""
        if isinstance(ip, IPv4Address):
            return ip
        # ::ffff:a.b.c.d carries an IPv4 address too - the same addresses _is_ipv4 accepts
        if isinstance(ip, IPv6Address) and ip.ipv4_mapped is not None:
            return ip.ipv4_mapped
        return None

    @staticmethod
    def _to_ipv6(ip: Optional[_BaseAddress]) -> Optional[IPv6Address]:
        """Return ``ip`` as IPv6; IPv4 addresses become ``::ffff:a.b.c.d``. Otherwise ``None``."""
        if isinstance(ip, IPv6Address):
            return ip
        if isinstance(ip, IPv4Address):
            return IPv6Address("::ffff:" + str(ip))
        return None

    def convert(self, df: pd.DataFrame) -> pd.DataFrame:
        """Convert ip address to int.

        Replaces ``ip_column`` in ``df`` with the ``_v4``/``_v6`` encodings and re-registers
        the search keys and column renamings accordingly.

        :raises ValidationError: if no value of ``ip_column`` parses as an IP address.
        """
        self.logger.info("Convert ip address to int")
        original_ip = self.columns_renaming[self.ip_column]

        df[self.ip_column] = df[self.ip_column].apply(self._safe_ip_parse)
        if df[self.ip_column].isnull().all():
            raise ValidationError(self.bundle.get("invalid_ip").format(self.ip_column))

        # legacy support
        ipv4 = self.ip_column + "_v4"
        df[ipv4] = df[self.ip_column].apply(self._to_ipv4).apply(self._ip_to_int).astype("Int64")
        self.search_keys[ipv4] = SearchKey.IP
        self.columns_renaming[ipv4] = original_ip

        ipv6 = self.ip_column + "_v6"
        df[ipv6] = df[self.ip_column].apply(self._to_ipv6).apply(self._ip_to_int_str).astype("string")
        df = df.drop(columns=self.ip_column)
        del self.search_keys[self.ip_column]
        del self.columns_renaming[self.ip_column]
        self.search_keys[ipv6] = SearchKey.IPV6_ADDRESS
        self.columns_renaming[ipv6] = original_ip  # could be __unnest_ip...

        return df
110
+
111
+
13
112
  class IpToCountrySearchKeyConverter:
14
113
  url = "http://ip-api.com/json/{}"
15
114