upgini 1.1.315a3579.dev1__py3-none-any.whl → 1.1.316a1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of upgini might be problematic. (The original page linked to more details here; the link is not preserved in this extract.)

upgini/http.py CHANGED
@@ -39,17 +39,6 @@ from upgini.metadata import (
39
39
  from upgini.resource_bundle import bundle
40
40
  from upgini.utils.track_info import get_track_metrics
41
41
 
42
- # try:
43
- # from importlib.metadata import version # type: ignore
44
-
45
- # __version__ = version("upgini")
46
- # except ImportError:
47
- # try:
48
- # from importlib_metadata import version # type: ignore
49
-
50
- # __version__ = version("upgini")
51
- # except ImportError:
52
- # __version__ = "Upgini wasn't installed"
53
42
 
54
43
  UPGINI_URL: str = "UPGINI_URL"
55
44
  UPGINI_API_KEY: str = "UPGINI_API_KEY"
@@ -452,18 +441,18 @@ class _RestClient:
452
441
  content = file.read()
453
442
  md5_hash.update(content)
454
443
  digest = md5_hash.hexdigest()
455
- metadata_with_md5 = metadata.copy(update={"checksumMD5": digest})
444
+ metadata_with_md5 = metadata.model_copy(update={"checksumMD5": digest})
456
445
 
457
446
  digest_sha256 = hashlib.sha256(
458
447
  pd.util.hash_pandas_object(pd.read_parquet(file_path, engine="fastparquet")).values
459
448
  ).hexdigest()
460
- metadata_with_md5 = metadata_with_md5.copy(update={"digest": digest_sha256})
449
+ metadata_with_md5 = metadata_with_md5.model_copy(update={"digest": digest_sha256})
461
450
 
462
451
  with open(file_path, "rb") as file:
463
452
  files = {
464
453
  "metadata": (
465
454
  "metadata.json",
466
- metadata_with_md5.json(exclude_none=True).encode(),
455
+ metadata_with_md5.model_dump_json(exclude_none=True).encode(),
467
456
  "application/json",
468
457
  ),
469
458
  "tracking": (
@@ -471,13 +460,13 @@ class _RestClient:
471
460
  dumps(track_metrics).encode(),
472
461
  "application/json",
473
462
  ),
474
- "metrics": ("metrics.json", metrics.json(exclude_none=True).encode(), "application/json"),
463
+ "metrics": ("metrics.json", metrics.model_dump_json(exclude_none=True).encode(), "application/json"),
475
464
  "file": (metadata_with_md5.name, file, "application/octet-stream"),
476
465
  }
477
466
  if search_customization is not None:
478
467
  files["customization"] = (
479
468
  "customization.json",
480
- search_customization.json(exclude_none=True).encode(),
469
+ search_customization.model_dump_json(exclude_none=True).encode(),
481
470
  "application/json",
482
471
  )
483
472
  additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
@@ -492,7 +481,7 @@ class _RestClient:
492
481
  def check_uploaded_file_v2(self, trace_id: str, file_upload_id: str, metadata: FileMetadata) -> bool:
493
482
  api_path = self.CHECK_UPLOADED_FILE_URL_FMT_V2.format(file_upload_id)
494
483
  response = self._with_unauth_retry(
495
- lambda: self._send_post_req(api_path, trace_id, metadata.json(exclude_none=True))
484
+ lambda: self._send_post_req(api_path, trace_id, metadata.model_dump_json(exclude_none=True))
496
485
  )
497
486
  return bool(response)
498
487
 
@@ -506,11 +495,11 @@ class _RestClient:
506
495
  ) -> SearchTaskResponse:
507
496
  api_path = self.INITIAL_SEARCH_WITHOUT_UPLOAD_URI_FMT_V2.format(file_upload_id)
508
497
  files = {
509
- "metadata": ("metadata.json", metadata.json(exclude_none=True).encode(), "application/json"),
510
- "metrics": ("metrics.json", metrics.json(exclude_none=True).encode(), "application/json"),
498
+ "metadata": ("metadata.json", metadata.model_dump_json(exclude_none=True).encode(), "application/json"),
499
+ "metrics": ("metrics.json", metrics.model_dump_json(exclude_none=True).encode(), "application/json"),
511
500
  }
512
501
  if search_customization is not None:
513
- files["customization"] = search_customization.json(exclude_none=True).encode()
502
+ files["customization"] = search_customization.model_dump_json(exclude_none=True).encode()
514
503
  additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
515
504
  response = self._with_unauth_retry(
516
505
  lambda: self._send_post_file_req_v2(
@@ -536,18 +525,18 @@ class _RestClient:
536
525
  content = file.read()
537
526
  md5_hash.update(content)
538
527
  digest = md5_hash.hexdigest()
539
- metadata_with_md5 = metadata.copy(update={"checksumMD5": digest})
528
+ metadata_with_md5 = metadata.model_copy(update={"checksumMD5": digest})
540
529
 
541
530
  digest_sha256 = hashlib.sha256(
542
531
  pd.util.hash_pandas_object(pd.read_parquet(file_path, engine="fastparquet")).values
543
532
  ).hexdigest()
544
- metadata_with_md5 = metadata_with_md5.copy(update={"digest": digest_sha256})
533
+ metadata_with_md5 = metadata_with_md5.model_copy(update={"digest": digest_sha256})
545
534
 
546
535
  with open(file_path, "rb") as file:
547
536
  files = {
548
537
  "metadata": (
549
538
  "metadata.json",
550
- metadata_with_md5.json(exclude_none=True).encode(),
539
+ metadata_with_md5.model_dump_json(exclude_none=True).encode(),
551
540
  "application/json",
552
541
  ),
553
542
  "tracking": (
@@ -555,13 +544,13 @@ class _RestClient:
555
544
  dumps(get_track_metrics(self.client_ip, self.client_visitorid)).encode(),
556
545
  "application/json",
557
546
  ),
558
- "metrics": ("metrics.json", metrics.json(exclude_none=True).encode(), "application/json"),
547
+ "metrics": ("metrics.json", metrics.model_dump_json(exclude_none=True).encode(), "application/json"),
559
548
  "file": (metadata_with_md5.name, file, "application/octet-stream"),
560
549
  }
561
550
  if search_customization is not None:
562
551
  files["customization"] = (
563
552
  "customization.json",
564
- search_customization.json(exclude_none=True).encode(),
553
+ search_customization.model_dump_json(exclude_none=True).encode(),
565
554
  "application/json",
566
555
  )
567
556
 
@@ -585,11 +574,11 @@ class _RestClient:
585
574
  ) -> SearchTaskResponse:
586
575
  api_path = self.VALIDATION_SEARCH_WITHOUT_UPLOAD_URI_FMT_V2.format(file_upload_id, initial_search_task_id)
587
576
  files = {
588
- "metadata": ("metadata.json", metadata.json(exclude_none=True).encode(), "application/json"),
589
- "metrics": ("metrics.json", metrics.json(exclude_none=True).encode(), "application/json"),
577
+ "metadata": ("metadata.json", metadata.model_dump_json(exclude_none=True).encode(), "application/json"),
578
+ "metrics": ("metrics.json", metrics.model_dump_json(exclude_none=True).encode(), "application/json"),
590
579
  }
591
580
  if search_customization is not None:
592
- files["customization"] = search_customization.json(exclude_none=True).encode()
581
+ files["customization"] = search_customization.model_dump_json(exclude_none=True).encode()
593
582
  additional_headers = {self.SEARCH_KEYS_HEADER_NAME: ",".join(self.search_keys_meaning_types(metadata))}
594
583
  response = self._with_unauth_retry(
595
584
  lambda: self._send_post_file_req_v2(
@@ -651,7 +640,7 @@ class _RestClient:
651
640
  with open(file_path, "rb") as file:
652
641
  files = {
653
642
  "file": (metadata.name, file, "application/octet-stream"),
654
- "metadata": ("metadata.json", metadata.json(exclude_none=True).encode(), "application/json"),
643
+ "metadata": ("metadata.json", metadata.model_dump_json(exclude_none=True).encode(), "application/json"),
655
644
  }
656
645
 
657
646
  return self._send_post_file_req_v2(api_path, files)
@@ -661,12 +650,12 @@ class _RestClient:
661
650
  def get_search_file_metadata(self, search_task_id: str, trace_id: str) -> FileMetadata:
662
651
  api_path = self.SEARCH_FILE_METADATA_URI_FMT_V2.format(search_task_id)
663
652
  response = self._with_unauth_retry(lambda: self._send_get_req(api_path, trace_id))
664
- return FileMetadata.parse_obj(response)
653
+ return FileMetadata.model_validate(response)
665
654
 
666
655
  def get_provider_search_metadata_v3(self, provider_search_task_id: str, trace_id: str) -> ProviderTaskMetadataV2:
667
656
  api_path = self.SEARCH_TASK_METADATA_FMT_V3.format(provider_search_task_id)
668
657
  response = self._with_unauth_retry(lambda: self._send_get_req(api_path, trace_id))
669
- return ProviderTaskMetadataV2.parse_obj(response)
658
+ return ProviderTaskMetadataV2.model_validate(response)
670
659
 
671
660
  def get_current_transform_usage(self, trace_id) -> TransformUsage:
672
661
  track_metrics = get_track_metrics(self.client_ip, self.client_visitorid)
upgini/lazy_import.py CHANGED
@@ -1,4 +1,6 @@
1
1
  import importlib
2
+ import importlib.util
3
+ import importlib.machinery
2
4
 
3
5
 
4
6
  class LazyImport:
@@ -10,7 +12,18 @@ class LazyImport:
10
12
 
11
13
  def _load(self):
12
14
  if self._module is None:
13
- self._module = importlib.import_module(self.module_name)
15
+ # Load module and save link to it
16
+ spec = importlib.util.find_spec(self.module_name)
17
+ if spec is None:
18
+ raise ImportError(f"Module {self.module_name} not found")
19
+
20
+ # Create module
21
+ self._module = importlib.util.module_from_spec(spec)
22
+
23
+ # Execute module
24
+ spec.loader.exec_module(self._module)
25
+
26
+ # Get class from module
14
27
  self._class = getattr(self._module, self.class_name)
15
28
 
16
29
  def __call__(self, *args, **kwargs):
upgini/metadata.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from enum import Enum
4
- from typing import Dict, List, Optional, Set
4
+ from typing import Any, Dict, List, Optional, Set, Union
5
5
 
6
6
  from pydantic import BaseModel
7
7
 
@@ -113,6 +113,21 @@ class SearchKey(Enum):
113
113
  if meaning_type == FileColumnMeaningType.MSISDN_RANGE_TO:
114
114
  return SearchKey.MSISDN_RANGE_TO
115
115
 
116
+ @staticmethod
117
+ def find_key(search_keys: Dict[str, SearchKey], keys: Union[SearchKey, List[SearchKey]]) -> Optional[SearchKey]:
118
+ if isinstance(keys, SearchKey):
119
+ keys = [keys]
120
+ for col, key_type in search_keys.items():
121
+ if key_type in keys:
122
+ return col
123
+ return None
124
+
125
+ @staticmethod
126
+ def find_all_keys(search_keys: Dict[str, SearchKey], keys: Union[SearchKey, List[SearchKey]]) -> List[SearchKey]:
127
+ if isinstance(keys, SearchKey):
128
+ keys = [keys]
129
+ return [col for col, key_type in search_keys.items() if key_type in keys]
130
+
116
131
 
117
132
  class DataType(Enum):
118
133
  INT = "INT"
@@ -157,23 +172,23 @@ class FileMetricsInterval(BaseModel):
157
172
  date_cut: float
158
173
  count: float
159
174
  valid_count: float
160
- avg_target: Optional[float] # not for multiclass
161
- avg_score_etalon: Optional[float]
175
+ avg_target: Optional[float] = None # not for multiclass
176
+ avg_score_etalon: Optional[float] = None
162
177
 
163
178
 
164
179
  class FileMetrics(BaseModel):
165
180
  # etalon metadata
166
- task_type: Optional[ModelTaskType]
167
- label: Optional[ModelLabelType]
168
- count: Optional[int]
169
- valid_count: Optional[int]
170
- valid_rate: Optional[float]
171
- avg_target: Optional[float]
172
- metrics_binary_etalon: Optional[BinaryTask]
173
- metrics_regression_etalon: Optional[RegressionTask]
174
- metrics_multiclass_etalon: Optional[MulticlassTask]
175
- cuts: Optional[List[float]]
176
- interval: Optional[List[FileMetricsInterval]]
181
+ task_type: Optional[ModelTaskType] = None
182
+ label: Optional[ModelLabelType] = None
183
+ count: Optional[int] = None
184
+ valid_count: Optional[int] = None
185
+ valid_rate: Optional[float] = None
186
+ avg_target: Optional[float] = None
187
+ metrics_binary_etalon: Optional[BinaryTask] = None
188
+ metrics_regression_etalon: Optional[RegressionTask] = None
189
+ metrics_multiclass_etalon: Optional[MulticlassTask] = None
190
+ cuts: Optional[List[float]] = None
191
+ interval: Optional[List[FileMetricsInterval]] = None
177
192
 
178
193
 
179
194
  class NumericInterval(BaseModel):
@@ -187,25 +202,25 @@ class FileColumnMetadata(BaseModel):
187
202
  dataType: DataType
188
203
  meaningType: FileColumnMeaningType
189
204
  minMaxValues: Optional[NumericInterval] = None
190
- originalName: Optional[str]
205
+ originalName: Optional[str] = None
191
206
  # is this column contains keys from multiple key columns like msisdn1, msisdn2
192
207
  isUnnest: bool = False
193
208
  # list of original etalon key column names like msisdn1, msisdn2
194
- unnestKeyNames: Optional[List[str]]
209
+ unnestKeyNames: Optional[List[str]] = None
195
210
 
196
211
 
197
212
  class FileMetadata(BaseModel):
198
213
  name: str
199
- description: Optional[str]
214
+ description: Optional[str] = None
200
215
  columns: List[FileColumnMetadata]
201
216
  searchKeys: List[List[str]]
202
- excludeFeaturesSources: Optional[List[str]]
203
- hierarchicalGroupKeys: Optional[List[str]]
204
- hierarchicalSubgroupKeys: Optional[List[str]]
205
- taskType: Optional[ModelTaskType]
206
- rowsCount: Optional[int]
207
- checksumMD5: Optional[str]
208
- digest: Optional[str]
217
+ excludeFeaturesSources: Optional[List[str]] = None
218
+ hierarchicalGroupKeys: Optional[List[str]] = None
219
+ hierarchicalSubgroupKeys: Optional[List[str]] = None
220
+ taskType: Optional[ModelTaskType] = None
221
+ rowsCount: Optional[int] = None
222
+ checksumMD5: Optional[str] = None
223
+ digest: Optional[str] = None
209
224
 
210
225
  def column_by_name(self, name: str) -> Optional[FileColumnMetadata]:
211
226
  for c in self.columns:
@@ -229,17 +244,17 @@ class FeaturesMetadataV2(BaseModel):
229
244
  source: str
230
245
  hit_rate: float
231
246
  shap_value: float
232
- commercial_schema: Optional[str]
233
- data_provider: Optional[str]
234
- data_providers: Optional[List[str]]
235
- data_provider_link: Optional[str]
236
- data_provider_links: Optional[List[str]]
237
- data_source: Optional[str]
238
- data_sources: Optional[List[str]]
239
- data_source_link: Optional[str]
240
- data_source_links: Optional[List[str]]
241
- doc_link: Optional[str]
242
- update_frequency: Optional[str]
247
+ commercial_schema: Optional[str] = None
248
+ data_provider: Optional[str] = None
249
+ data_providers: Optional[List[str]] = None
250
+ data_provider_link: Optional[str] = None
251
+ data_provider_links: Optional[List[str]] = None
252
+ data_source: Optional[str] = None
253
+ data_sources: Optional[List[str]] = None
254
+ data_source_link: Optional[str] = None
255
+ data_source_links: Optional[List[str]] = None
256
+ doc_link: Optional[str] = None
257
+ update_frequency: Optional[str] = None
243
258
 
244
259
 
245
260
  class HitRateMetrics(BaseModel):
@@ -259,48 +274,48 @@ class ModelEvalSet(BaseModel):
259
274
  class BaseColumnMetadata(BaseModel):
260
275
  original_name: str
261
276
  hashed_name: str
262
- ads_definition_id: Optional[str]
277
+ ads_definition_id: Optional[str] = None
263
278
  is_augmented: bool
264
279
 
265
280
 
266
281
  class GeneratedFeatureMetadata(BaseModel):
267
- alias: Optional[str]
282
+ alias: Optional[str] = None
268
283
  formula: str
269
284
  display_index: str
270
285
  base_columns: List[BaseColumnMetadata]
271
- operator_params: Optional[Dict[str, str]]
286
+ operator_params: Optional[Dict[str, str]] = None
272
287
 
273
288
 
274
289
  class ProviderTaskMetadataV2(BaseModel):
275
290
  features: List[FeaturesMetadataV2]
276
- hit_rate_metrics: Optional[HitRateMetrics]
277
- eval_set_metrics: Optional[List[ModelEvalSet]]
278
- zero_hit_rate_search_keys: Optional[List[str]]
279
- features_used_for_embeddings: Optional[List[str]]
280
- shuffle_kfold: Optional[bool]
281
- generated_features: Optional[List[GeneratedFeatureMetadata]]
291
+ hit_rate_metrics: Optional[HitRateMetrics] = None
292
+ eval_set_metrics: Optional[List[ModelEvalSet]] = None
293
+ zero_hit_rate_search_keys: Optional[List[str]] = None
294
+ features_used_for_embeddings: Optional[List[str]] = None
295
+ shuffle_kfold: Optional[bool] = None
296
+ generated_features: Optional[List[GeneratedFeatureMetadata]] = None
282
297
 
283
298
 
284
299
  class FeaturesFilter(BaseModel):
285
- minImportance: Optional[float]
286
- maxPSI: Optional[float]
287
- maxCount: Optional[int]
288
- selectedFeatures: Optional[List[str]]
300
+ minImportance: Optional[float] = None
301
+ maxPSI: Optional[float] = None
302
+ maxCount: Optional[int] = None
303
+ selectedFeatures: Optional[List[str]] = None
289
304
 
290
305
 
291
306
  class RuntimeParameters(BaseModel):
292
- properties: Dict[str, str] = {}
307
+ properties: Dict[str, Any] = {}
293
308
 
294
309
 
295
310
  class SearchCustomization(BaseModel):
296
- featuresFilter: Optional[FeaturesFilter]
297
- extractFeatures: Optional[bool]
298
- accurateModel: Optional[bool]
299
- importanceThreshold: Optional[float]
300
- maxFeatures: Optional[int]
301
- returnScores: Optional[bool]
302
- runtimeParameters: Optional[RuntimeParameters]
303
- metricsCalculation: Optional[bool]
311
+ featuresFilter: Optional[FeaturesFilter] = None
312
+ extractFeatures: Optional[bool] = None
313
+ accurateModel: Optional[bool] = None
314
+ importanceThreshold: Optional[float] = None
315
+ maxFeatures: Optional[int] = None
316
+ returnScores: Optional[bool] = None
317
+ runtimeParameters: Optional[RuntimeParameters] = None
318
+ metricsCalculation: Optional[bool] = None
304
319
 
305
320
  def __repr__(self):
306
321
  return (
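[File header missing from this extract: the hunk below adds a new 202-line file; it is not part of upgini/metadata.py.]
@@ -0,0 +1,202 @@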
1
+ import hashlib
2
+ from logging import Logger, getLogger
3
+ from typing import Dict, List
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from pandas.api.types import is_bool_dtype as is_bool
8
+ from pandas.api.types import is_datetime64_any_dtype as is_datetime
9
+ from pandas.api.types import (
10
+ is_float_dtype,
11
+ is_numeric_dtype,
12
+ is_object_dtype,
13
+ is_string_dtype,
14
+ )
15
+
16
+ from upgini.errors import ValidationError
17
+ from upgini.metadata import (
18
+ ENTITY_SYSTEM_RECORD_ID,
19
+ EVAL_SET_INDEX,
20
+ SEARCH_KEY_UNNEST,
21
+ SYSTEM_RECORD_ID,
22
+ TARGET,
23
+ SearchKey,
24
+ )
25
+ from upgini.resource_bundle import ResourceBundle, get_custom_bundle
26
+ from upgini.utils import find_numbers_with_decimal_comma
27
+ from upgini.utils.datetime_utils import DateTimeSearchKeyConverter
28
+ from upgini.utils.phone_utils import PhoneSearchKeyConverter
29
+ from upgini.utils.warning_counter import WarningCounter
30
+
31
+
32
+ class Normalizer:
33
+
34
+ MAX_STRING_FEATURE_LENGTH = 24573
35
+
36
+ def __init__(
37
+ self,
38
+ search_keys: Dict[str, SearchKey],
39
+ generated_features: List[str],
40
+ bundle: ResourceBundle = None,
41
+ logger: Logger = None,
42
+ warnings_counter: WarningCounter = None,
43
+ silent_mode=False,
44
+ ):
45
+ self.search_keys = search_keys
46
+ self.generated_features = generated_features
47
+ self.bundle = bundle or get_custom_bundle()
48
+ self.logger = logger or getLogger()
49
+ self.warnings_counter = warnings_counter or WarningCounter()
50
+ self.silent_mode = silent_mode
51
+ self.columns_renaming = {}
52
+
53
+ def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
54
+ df = df.copy()
55
+ df = self._rename_columns(df)
56
+
57
+ df = self._remove_dates_from_features(df)
58
+
59
+ df = self._cut_too_long_string_values(df)
60
+
61
+ df = self._convert_bools(df)
62
+
63
+ df = self._convert_float16(df)
64
+
65
+ df = self._correct_decimal_comma(df)
66
+
67
+ df = self._convert_phone_numbers(df)
68
+
69
+ df = self.__convert_features_types(df)
70
+
71
+ return df
72
+
73
+ def _rename_columns(self, df: pd.DataFrame):
74
+ # logger.info("Replace restricted symbols in column names")
75
+ new_columns = []
76
+ dup_counter = 0
77
+ for column in df.columns:
78
+ if column in [
79
+ TARGET,
80
+ EVAL_SET_INDEX,
81
+ SYSTEM_RECORD_ID,
82
+ ENTITY_SYSTEM_RECORD_ID,
83
+ SEARCH_KEY_UNNEST,
84
+ DateTimeSearchKeyConverter.DATETIME_COL,
85
+ ] + self.generated_features:
86
+ self.columns_renaming[column] = column
87
+ new_columns.append(column)
88
+ continue
89
+
90
+ new_column = str(column)
91
+ suffix = hashlib.sha256(new_column.encode()).hexdigest()[:6]
92
+ if len(new_column) == 0:
93
+ raise ValidationError(self.bundle.get("dataset_empty_column_names"))
94
+ # db limit for column length
95
+ if len(new_column) > 250:
96
+ new_column = new_column[:250]
97
+
98
+ # make column name unique relative to server features
99
+ new_column = f"{new_column}_{suffix}"
100
+
101
+ new_column = new_column.lower()
102
+
103
+ # if column starts with non alphabetic symbol then add "a" to the beginning of string
104
+ if ord(new_column[0]) not in range(ord("a"), ord("z") + 1):
105
+ new_column = "a" + new_column
106
+
107
+ # replace unsupported characters to "_"
108
+ for idx, c in enumerate(new_column):
109
+ if ord(c) not in range(ord("a"), ord("z") + 1) and ord(c) not in range(ord("0"), ord("9") + 1):
110
+ new_column = new_column[:idx] + "_" + new_column[idx + 1 :]
111
+
112
+ if new_column in new_columns:
113
+ new_column = f"{new_column}_{dup_counter}"
114
+ dup_counter += 1
115
+ new_columns.append(new_column)
116
+
117
+ # df.columns.values[col_idx] = new_column
118
+ # rename(columns={column: new_column}, inplace=True)
119
+
120
+ if new_column != column and column in self.search_keys:
121
+ self.search_keys[new_column] = self.search_keys[column]
122
+ del self.search_keys[column]
123
+ self.columns_renaming[new_column] = str(column)
124
+ df.columns = new_columns
125
+ return df
126
+
127
+ def _get_features(self, df: pd.DataFrame) -> List[str]:
128
+ system_columns = [ENTITY_SYSTEM_RECORD_ID, EVAL_SET_INDEX, SEARCH_KEY_UNNEST, SYSTEM_RECORD_ID, TARGET]
129
+ features = set(df.columns) - set(self.search_keys.keys()) - set(system_columns)
130
+ return sorted(list(features))
131
+
132
+ def _remove_dates_from_features(self, df: pd.DataFrame):
133
+ features = self._get_features(df)
134
+
135
+ removed_features = []
136
+ for f in features:
137
+ if is_datetime(df[f]) or isinstance(df[f].dtype, pd.PeriodDtype):
138
+ removed_features.append(f)
139
+ df.drop(columns=f, inplace=True)
140
+
141
+ if removed_features:
142
+ msg = self.bundle.get("dataset_date_features").format(removed_features)
143
+ self.logger.warning(msg)
144
+ if not self.silent_mode:
145
+ print(msg)
146
+ self.warnings_counter.increment()
147
+
148
+ return df
149
+
150
+ def _cut_too_long_string_values(self, df: pd.DataFrame):
151
+ """Check that string values less than maximum characters for LLM"""
152
+ # logger.info("Validate too long string values")
153
+ for col in df.columns:
154
+ if is_string_dtype(df[col]) or is_object_dtype(df[col]):
155
+ max_length: int = df[col].astype("str").str.len().max()
156
+ if max_length > self.MAX_STRING_FEATURE_LENGTH:
157
+ df[col] = df[col].astype("str").str.slice(stop=self.MAX_STRING_FEATURE_LENGTH)
158
+
159
+ return df
160
+
161
+ @staticmethod
162
+ def _convert_bools(df: pd.DataFrame):
163
+ """Convert bool columns to string"""
164
+ # logger.info("Converting bool to int")
165
+ for col in df.columns:
166
+ if is_bool(df[col]):
167
+ df[col] = df[col].astype("str")
168
+ return df
169
+
170
+ @staticmethod
171
+ def _convert_float16(df: pd.DataFrame):
172
+ """Convert float16 to float"""
173
+ # logger.info("Converting float16 to float")
174
+ for col in df.columns:
175
+ if is_float_dtype(df[col]):
176
+ df[col] = df[col].astype("float64")
177
+ return df
178
+
179
+ def _correct_decimal_comma(self, df: pd.DataFrame):
180
+ """Check DataSet for decimal commas and fix them"""
181
+ # logger.info("Correct decimal commas")
182
+ columns_to_fix = find_numbers_with_decimal_comma(df)
183
+ if len(columns_to_fix) > 0:
184
+ self.logger.warning(f"Convert strings with decimal comma to float: {columns_to_fix}")
185
+ for col in columns_to_fix:
186
+ df[col] = df[col].astype("string").str.replace(",", ".", regex=False).astype(np.float64)
187
+ return df
188
+
189
+ def _convert_phone_numbers(self, df: pd.DataFrame) -> pd.DataFrame:
190
+ maybe_country_col = SearchKey.find_key(self.search_keys, SearchKey.COUNTRY)
191
+ for phone_col in SearchKey.find_all_keys(self.search_keys, SearchKey.PHONE):
192
+ converter = PhoneSearchKeyConverter(phone_col, maybe_country_col)
193
+ df = converter.convert(df)
194
+ return df
195
+
196
+ def __convert_features_types(self, df: pd.DataFrame):
197
+ # self.logger.info("Convert features to supported data types")
198
+
199
+ for f in self._get_features(df):
200
+ if not is_numeric_dtype(df[f]):
201
+ df[f] = df[f].astype("string")
202
+ return df
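[File header missing from this extract: the hunk below belongs to a different, already existing file, not to the 202-line file added above.]
@@ -4,6 +4,22 @@ from pandas.api.types import is_object_dtype, is_string_dtype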
4
4
  from upgini.utils.base_search_key_detector import BaseSearchKeyDetector
5
5
 
6
6
 
7
+ class CountrySearchKeyConverter:
8
+
9
+ def __init__(self, country_col: str):
10
+ self.country_col = country_col
11
+
12
+ def convert(self, df: pd.DataFrame) -> pd.DataFrame:
13
+ df[self.country_col] = (
14
+ df[self.country_col]
15
+ .astype("string")
16
+ .str.upper()
17
+ .str.replace(r"[^A-Z]", "", regex=True)
18
+ .str.replace("UK", "GB", regex=False)
19
+ )
20
+ return df
21
+
22
+
7
23
  class CountrySearchKeyDetector(BaseSearchKeyDetector):
8
24
  def _is_search_key_by_name(self, column_name: str) -> bool:
9
25
  return "country" in str(column_name).lower()