additory-0.1.0a4-py3-none-any.whl → additory-0.1.1a1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- additory/__init__.py +58 -14
- additory/common/__init__.py +31 -147
- additory/common/column_selector.py +255 -0
- additory/common/distributions.py +286 -613
- additory/common/extractors.py +313 -0
- additory/common/knn_imputation.py +332 -0
- additory/common/result.py +380 -0
- additory/common/strategy_parser.py +243 -0
- additory/common/unit_conversions.py +338 -0
- additory/common/validation.py +283 -103
- additory/core/__init__.py +34 -22
- additory/core/backend.py +258 -0
- additory/core/config.py +177 -305
- additory/core/logging.py +230 -24
- additory/core/memory_manager.py +157 -495
- additory/expressions/__init__.py +2 -23
- additory/expressions/compiler.py +457 -0
- additory/expressions/engine.py +264 -487
- additory/expressions/integrity.py +179 -0
- additory/expressions/loader.py +263 -0
- additory/expressions/parser.py +363 -167
- additory/expressions/resolver.py +274 -0
- additory/functions/__init__.py +1 -0
- additory/functions/analyze/__init__.py +144 -0
- additory/functions/analyze/cardinality.py +58 -0
- additory/functions/analyze/correlations.py +66 -0
- additory/functions/analyze/distributions.py +53 -0
- additory/functions/analyze/duplicates.py +49 -0
- additory/functions/analyze/features.py +61 -0
- additory/functions/analyze/imputation.py +66 -0
- additory/functions/analyze/outliers.py +65 -0
- additory/functions/analyze/patterns.py +65 -0
- additory/functions/analyze/presets.py +72 -0
- additory/functions/analyze/quality.py +59 -0
- additory/functions/analyze/timeseries.py +53 -0
- additory/functions/analyze/types.py +45 -0
- additory/functions/expressions/__init__.py +161 -0
- additory/functions/snapshot/__init__.py +82 -0
- additory/functions/snapshot/filter.py +119 -0
- additory/functions/synthetic/__init__.py +113 -0
- additory/functions/synthetic/mode_detector.py +47 -0
- additory/functions/synthetic/strategies/__init__.py +1 -0
- additory/functions/synthetic/strategies/advanced.py +35 -0
- additory/functions/synthetic/strategies/augmentative.py +160 -0
- additory/functions/synthetic/strategies/generative.py +168 -0
- additory/functions/synthetic/strategies/presets.py +116 -0
- additory/functions/to/__init__.py +188 -0
- additory/functions/to/lookup.py +351 -0
- additory/functions/to/merge.py +189 -0
- additory/functions/to/sort.py +91 -0
- additory/functions/to/summarize.py +170 -0
- additory/functions/transform/__init__.py +140 -0
- additory/functions/transform/datetime.py +79 -0
- additory/functions/transform/extract.py +85 -0
- additory/functions/transform/harmonize.py +105 -0
- additory/functions/transform/knn.py +62 -0
- additory/functions/transform/onehotencoding.py +68 -0
- additory/functions/transform/transpose.py +42 -0
- additory-0.1.1a1.dist-info/METADATA +83 -0
- additory-0.1.1a1.dist-info/RECORD +62 -0
- additory/analysis/__init__.py +0 -48
- additory/analysis/cardinality.py +0 -126
- additory/analysis/correlations.py +0 -124
- additory/analysis/distributions.py +0 -376
- additory/analysis/quality.py +0 -158
- additory/analysis/scan.py +0 -400
- additory/common/backend.py +0 -371
- additory/common/column_utils.py +0 -191
- additory/common/exceptions.py +0 -62
- additory/common/lists.py +0 -229
- additory/common/patterns.py +0 -240
- additory/common/resolver.py +0 -567
- additory/common/sample_data.py +0 -182
- additory/core/ast_builder.py +0 -165
- additory/core/backends/__init__.py +0 -23
- additory/core/backends/arrow_bridge.py +0 -483
- additory/core/backends/cudf_bridge.py +0 -355
- additory/core/column_positioning.py +0 -358
- additory/core/compiler_polars.py +0 -166
- additory/core/enhanced_cache_manager.py +0 -1119
- additory/core/enhanced_matchers.py +0 -473
- additory/core/enhanced_version_manager.py +0 -325
- additory/core/executor.py +0 -59
- additory/core/integrity_manager.py +0 -477
- additory/core/loader.py +0 -190
- additory/core/namespace_manager.py +0 -657
- additory/core/parser.py +0 -176
- additory/core/polars_expression_engine.py +0 -601
- additory/core/registry.py +0 -177
- additory/core/sample_data_manager.py +0 -492
- additory/core/user_namespace.py +0 -751
- additory/core/validator.py +0 -27
- additory/dynamic_api.py +0 -352
- additory/expressions/proxy.py +0 -549
- additory/expressions/registry.py +0 -313
- additory/expressions/samples.py +0 -492
- additory/synthetic/__init__.py +0 -13
- additory/synthetic/column_name_resolver.py +0 -149
- additory/synthetic/deduce.py +0 -259
- additory/synthetic/distributions.py +0 -22
- additory/synthetic/forecast.py +0 -1132
- additory/synthetic/linked_list_parser.py +0 -415
- additory/synthetic/namespace_lookup.py +0 -129
- additory/synthetic/smote.py +0 -320
- additory/synthetic/strategies.py +0 -926
- additory/synthetic/synthesizer.py +0 -713
- additory/utilities/__init__.py +0 -53
- additory/utilities/encoding.py +0 -600
- additory/utilities/games.py +0 -300
- additory/utilities/keys.py +0 -8
- additory/utilities/lookup.py +0 -103
- additory/utilities/matchers.py +0 -216
- additory/utilities/resolvers.py +0 -286
- additory/utilities/settings.py +0 -167
- additory/utilities/units.py +0 -749
- additory/utilities/validators.py +0 -153
- additory-0.1.0a4.dist-info/METADATA +0 -311
- additory-0.1.0a4.dist-info/RECORD +0 -72
- additory-0.1.0a4.dist-info/licenses/LICENSE +0 -21
- {additory-0.1.0a4.dist-info → additory-0.1.1a1.dist-info}/WHEEL +0 -0
- {additory-0.1.0a4.dist-info → additory-0.1.1a1.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ additory/common/extractors.py
@@ -0,0 +1,313 @@
+"""
+Feature extraction utilities for Additory.
+
+Provides functions to extract features from various column types
+(datetime, email, text, URL, phone, etc.).
+"""
+
+import re
+from typing import Dict, List, Optional
+import polars as pl
+
+
+def extract_datetime_features(series: pl.Series, features: List[str]) -> Dict[str, pl.Series]:
+    """
+    Extract datetime features from a datetime column.
+
+    Args:
+        series: Polars Series with datetime values
+        features: List of features to extract
+
+    Returns:
+        Dictionary mapping feature name to Series
+
+    Supported Features:
+        - 'year', 'month', 'day', 'hour', 'minute', 'second'
+        - 'day_of_week', 'day_of_year', 'week', 'quarter'
+        - 'time_of_day' ('morning', 'afternoon', 'evening', 'night')
+        - 'is_weekend', 'is_business_day'
+
+    Example:
+        features = extract_datetime_features(df['timestamp'], ['hour', 'day_of_week'])
+    """
+    result = {}
+
+    for feature in features:
+        if feature == 'year':
+            result[feature] = series.dt.year()
+        elif feature == 'month':
+            result[feature] = series.dt.month()
+        elif feature == 'day':
+            result[feature] = series.dt.day()
+        elif feature == 'hour':
+            result[feature] = series.dt.hour()
+        elif feature == 'minute':
+            result[feature] = series.dt.minute()
+        elif feature == 'second':
+            result[feature] = series.dt.second()
+        elif feature == 'day_of_week':
+            result[feature] = series.dt.weekday()
+        elif feature == 'day_of_year':
+            result[feature] = series.dt.ordinal_day()
+        elif feature == 'week':
+            result[feature] = series.dt.week()
+        elif feature == 'quarter':
+            result[feature] = series.dt.quarter()
+        elif feature == 'time_of_day':
+            hour = series.dt.hour()
+            time_of_day = pl.when(hour < 6).then(pl.lit('night')) \
+                .when(hour < 12).then(pl.lit('morning')) \
+                .when(hour < 18).then(pl.lit('afternoon')) \
+                .otherwise(pl.lit('evening'))
+            # Evaluate the expression to get a Series
+            result[feature] = pl.select(time_of_day.alias('time_of_day')).to_series()
+        elif feature == 'is_weekend':
+            weekday = series.dt.weekday()
+            result[feature] = (weekday >= 6)
+        elif feature == 'is_business_day':
+            weekday = series.dt.weekday()
+            result[feature] = (weekday < 5)
+        else:
+            raise ValueError(f"Unsupported datetime feature: {feature}")
+
+    return result
+
+
+def extract_email_features(series: pl.Series, features: List[str]) -> Dict[str, pl.Series]:
+    """
+    Extract features from email addresses.
+
+    Args:
+        series: Polars Series with email strings
+        features: List of features to extract
+
+    Returns:
+        Dictionary mapping feature name to Series
+
+    Supported Features:
+        - 'domain' - Email domain (e.g., 'gmail.com')
+        - 'username' - Username part (before @)
+        - 'tld' - Top-level domain (e.g., 'com')
+        - 'is_free_email' - Boolean (gmail, yahoo, etc.)
+        - 'is_corporate' - Boolean (not free email)
+
+    Example:
+        features = extract_email_features(df['email'], ['domain', 'is_free_email'])
+    """
+    result = {}
+
+    # Common free email providers
+    free_providers = {'gmail.com', 'yahoo.com', 'hotmail.com', 'outlook.com',
+                      'aol.com', 'icloud.com', 'mail.com', 'protonmail.com'}
+
+    for feature in features:
+        if feature == 'domain':
+            # Extract domain (part after @)
+            result[feature] = series.str.extract(r'@(.+)$', 1)
+        elif feature == 'username':
+            # Extract username (part before @)
+            result[feature] = series.str.extract(r'^([^@]+)@', 1)
+        elif feature == 'tld':
+            # Extract top-level domain (last part after .)
+            result[feature] = series.str.extract(r'\.([^.]+)$', 1)
+        elif feature == 'is_free_email':
+            # Check if domain is in free providers list
+            domain = series.str.extract(r'@(.+)$', 1)
+            result[feature] = domain.is_in(list(free_providers))
+        elif feature == 'is_corporate':
+            # Opposite of is_free_email
+            domain = series.str.extract(r'@(.+)$', 1)
+            result[feature] = ~domain.is_in(list(free_providers))
+        else:
+            raise ValueError(f"Unsupported email feature: {feature}")
+
+    return result
+
+
+def extract_text_features(series: pl.Series, features: List[str]) -> Dict[str, pl.Series]:
+    """
+    Extract features from text columns.
+
+    Args:
+        series: Polars Series with text strings
+        features: List of features to extract
+
+    Returns:
+        Dictionary mapping feature name to Series
+
+    Supported Features:
+        - 'length' - Character count
+        - 'word_count' - Number of words
+        - 'sentence_count' - Number of sentences
+        - 'has_numbers' - Boolean
+        - 'has_special_chars' - Boolean
+        - 'is_uppercase' - Boolean
+        - 'is_lowercase' - Boolean
+
+    Example:
+        features = extract_text_features(df['description'], ['length', 'word_count'])
+    """
+    result = {}
+
+    for feature in features:
+        if feature == 'length':
+            result[feature] = series.str.len_chars()
+        elif feature == 'word_count':
+            # Count words (split by whitespace)
+            result[feature] = series.str.split(' ').list.len()
+        elif feature == 'sentence_count':
+            # Count sentences (split by . ! ?)
+            result[feature] = series.str.count_matches(r'[.!?]')
+        elif feature == 'has_numbers':
+            result[feature] = series.str.contains(r'\d')
+        elif feature == 'has_special_chars':
+            result[feature] = series.str.contains(r'[^a-zA-Z0-9\s]')
+        elif feature == 'is_uppercase':
+            result[feature] = series.str.to_uppercase() == series
+        elif feature == 'is_lowercase':
+            result[feature] = series.str.to_lowercase() == series
+        else:
+            raise ValueError(f"Unsupported text feature: {feature}")
+
+    return result
+
+
+def split_column(series: pl.Series, delimiter: str, max_splits: Optional[int] = None) -> List[pl.Series]:
+    """
+    Split column by delimiter into multiple columns.
+
+    Args:
+        series: Polars Series to split
+        delimiter: Delimiter string
+        max_splits: Maximum number of splits (None = unlimited)
+
+    Returns:
+        List of Series (one per split part)
+
+    Example:
+        # Split 'red,green,blue' into 3 columns
+        parts = split_column(df['tags'], delimiter=',', max_splits=3)
+    """
+    # Split the series
+    split_series = series.str.split(delimiter)
+
+    # Determine number of parts
+    if max_splits is None:
+        # Find maximum number of parts
+        max_len = split_series.list.len().max()
+        if max_len is None:
+            # Empty series
+            return []
+        max_parts = int(max_len)
+    else:
+        max_parts = max_splits
+
+    # Extract each part as a separate series
+    result = []
+    for i in range(max_parts):
+        try:
+            part = split_series.list.get(i, null_on_oob=True)
+            result.append(part)
+        except Exception:
+            # If index is out of bounds, append None series
+            result.append(pl.Series([None] * len(series)))
+
+    return result
+
+
+def extract_url_features(series: pl.Series, features: List[str]) -> Dict[str, pl.Series]:
+    """
+    Extract features from URLs.
+
+    Args:
+        series: Polars Series with URL strings
+        features: List of features to extract
+
+    Returns:
+        Dictionary mapping feature name to Series
+
+    Supported Features:
+        - 'protocol' - http, https, ftp, etc.
+        - 'domain' - Domain name
+        - 'path' - URL path
+        - 'is_secure' - Boolean (https)
+
+    Example:
+        features = extract_url_features(df['url'], ['protocol', 'domain'])
+    """
+    result = {}
+
+    for feature in features:
+        if feature == 'protocol':
+            # Extract protocol (e.g., http, https)
+            result[feature] = series.str.extract(r'^([a-z]+)://', 1)
+        elif feature == 'domain':
+            # Extract domain (between :// and /)
+            result[feature] = series.str.extract(r'://([^/]+)', 1)
+        elif feature == 'path':
+            # Extract path (after domain)
+            result[feature] = series.str.extract(r'://[^/]+(/[^?#]*)', 1)
+        elif feature == 'is_secure':
+            # Check if protocol is https
+            protocol = series.str.extract(r'^([a-z]+)://', 1)
+            result[feature] = (protocol == 'https')
+        else:
+            raise ValueError(f"Unsupported URL feature: {feature}")
+
+    return result
+
+
+def extract_phone_features(series: pl.Series, features: List[str]) -> Dict[str, pl.Series]:
+    """
+    Extract features from phone numbers.
+
+    Args:
+        series: Polars Series with phone number strings
+        features: List of features to extract
+
+    Returns:
+        Dictionary mapping feature name to Series
+
+    Supported Features:
+        - 'country_code' - Country code
+        - 'area_code' - Area code
+        - 'is_valid' - Boolean (basic validation)
+
+    Example:
+        features = extract_phone_features(df['phone'], ['country_code', 'area_code'])
+    """
+    result = {}
+
+    for feature in features:
+        if feature == 'country_code':
+            # Extract country code (e.g., +1, +44)
+            result[feature] = series.str.extract(r'^\+?(\d{1,3})', 1)
+        elif feature == 'area_code':
+            # Extract area code (3 digits after country code)
+            result[feature] = series.str.extract(r'\(?(\d{3})\)?', 1)
+        elif feature == 'is_valid':
+            # Basic validation: has at least 10 digits
+            digits_only = series.str.replace_all(r'\D', '')
+            result[feature] = digits_only.str.len_chars() >= 10
+        else:
+            raise ValueError(f"Unsupported phone feature: {feature}")
+
+    return result
+
+
+def parse_pattern(series: pl.Series, pattern: str) -> pl.Series:
+    """
+    Extract pattern from strings using regex.
+
+    Args:
+        series: Polars Series with strings
+        pattern: Regex pattern to extract
+
+    Returns:
+        Series with extracted values
+
+    Example:
+        # Extract numbers from strings
+        numbers = parse_pattern(df['text'], r'\\d+')
+    """
+    return series.str.extract(pattern, 0)
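
For orientation, here is a minimal usage sketch of the new extractors module. The sample data is made up, and the import path `additory.common.extractors` is inferred from the RECORD listing above; this is not an official example from the package.

import polars as pl
from datetime import datetime
from additory.common.extractors import extract_datetime_features, extract_email_features

# Hypothetical sample frame
df = pl.DataFrame({
    "timestamp": [datetime(2024, 1, 6, 9, 30), datetime(2024, 1, 8, 19, 5)],
    "email": ["alice@gmail.com", "bob@acme.example"],
})

# Each helper returns a dict mapping feature name -> Series, which can be
# attached to the frame with with_columns()
dt_feats = extract_datetime_features(df["timestamp"], ["hour", "day_of_week", "is_weekend"])
df = df.with_columns([s.alias(f"ts_{name}") for name, s in dt_feats.items()])

email_feats = extract_email_features(df["email"], ["domain", "is_free_email"])
df = df.with_columns([s.alias(f"email_{name}") for name, s in email_feats.items()])

Note that unrecognized feature names raise ValueError rather than being silently skipped, so a bad feature list fails fast.
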
--- /dev/null
+++ additory/common/knn_imputation.py
@@ -0,0 +1,332 @@
+"""
+K-Nearest Neighbors imputation utilities for Additory.
+
+Provides KNN-based imputation for missing values in numeric columns.
+"""
+
+import math
+from typing import List, Optional
+import polars as pl
+import numpy as np
+
+
+def knn_impute(df: pl.DataFrame, columns: List[str], k: int = 5,
+               weights: str = 'distance', metric: str = 'euclidean') -> pl.DataFrame:
+    """
+    Impute missing values using K-Nearest Neighbors.
+
+    Args:
+        df: DataFrame with missing values
+        columns: Columns to impute
+        k: Number of neighbors (default: 5)
+        weights: Weighting strategy ('uniform' or 'distance')
+        metric: Distance metric ('euclidean', 'manhattan', 'cosine')
+
+    Returns:
+        DataFrame with imputed values
+
+    Example:
+        imputed_df = knn_impute(df, columns=['age', 'income'], k=5, weights='distance')
+    """
+    # Validate parameters
+    validate_knn_parameters(df, columns, k)
+
+    # Create a copy to avoid modifying original
+    result_df = df.clone()
+
+    # Get all numeric columns for distance calculation
+    all_numeric_cols = [col for col in df.columns
+                        if df[col].dtype in [pl.Int8, pl.Int16, pl.Int32, pl.Int64,
+                                             pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
+                                             pl.Float32, pl.Float64]]
+
+    # Convert to numpy for easier manipulation
+    data = df.select(columns).to_numpy().copy()  # Make a copy we can modify
+
+    # For each row with missing values
+    for row_idx in range(len(df)):
+        row = data[row_idx]
+
+        # Check if this row has any missing values
+        if not np.any(np.isnan(row)):
+            continue
+
+        # Check if this row has at least some non-missing values in ANY numeric column
+        # (needed for distance calculation)
+        row_all_cols = df.select(all_numeric_cols).row(row_idx)
+        has_any_non_missing = any(val is not None and not (isinstance(val, float) and np.isnan(val))
+                                  for val in row_all_cols)
+        if not has_any_non_missing:
+            # All numeric values missing, skip this row
+            continue
+
+        # Calculate distances to all other rows using ALL numeric columns
+        distances = calculate_distances(df, row_idx, all_numeric_cols, metric)
+
+        # Find k nearest neighbors
+        neighbor_indices = find_k_nearest(distances, k, exclude_idx=row_idx)
+
+        # Impute each missing value
+        for col_idx, col_name in enumerate(columns):
+            if np.isnan(row[col_idx]):
+                # Get values from neighbors
+                neighbor_values = []
+                neighbor_distances = []
+
+                for neighbor_idx in neighbor_indices:
+                    neighbor_value = data[neighbor_idx, col_idx]
+                    if not np.isnan(neighbor_value):
+                        neighbor_values.append(neighbor_value)
+                        neighbor_distances.append(distances[neighbor_idx])
+
+                # If we have neighbor values, compute weighted average
+                if neighbor_values:
+                    neighbor_values_series = pl.Series(neighbor_values)
+                    neighbor_distances_series = pl.Series(neighbor_distances)
+
+                    imputed_value = compute_weighted_average(
+                        neighbor_values_series,
+                        neighbor_distances_series,
+                        weights
+                    )
+
+                    # Update the data array
+                    data[row_idx, col_idx] = imputed_value
+
+    # Replace the columns in the result DataFrame with imputed data
+    for col_idx, col_name in enumerate(columns):
+        result_df = result_df.with_columns(
+            pl.Series(col_name, data[:, col_idx])
+        )
+
+    return result_df
+
+
+def calculate_distances(df: pl.DataFrame, row_idx: int, columns: List[str],
+                        metric: str) -> pl.Series:
+    """
+    Calculate distances from a row to all other rows.
+
+    Args:
+        df: DataFrame
+        row_idx: Index of row to calculate distances from
+        columns: Columns to use for distance calculation
+        metric: Distance metric
+
+    Returns:
+        Series of distances
+    """
+    data = df.select(columns).to_numpy()
+    row = data[row_idx]
+
+    distances = []
+    for i in range(len(df)):
+        if i == row_idx:
+            distances.append(float('inf'))  # Exclude self
+        else:
+            other_row = data[i]
+
+            # Only use non-missing values for distance calculation
+            mask = ~(np.isnan(row) | np.isnan(other_row))
+
+            if not np.any(mask):
+                # No common non-missing values
+                distances.append(float('inf'))
+            else:
+                row_clean = row[mask]
+                other_clean = other_row[mask]
+
+                if metric == 'euclidean':
+                    dist = euclidean_distance(
+                        pl.Series(row_clean),
+                        pl.Series(other_clean)
+                    )
+                elif metric == 'manhattan':
+                    dist = manhattan_distance(
+                        pl.Series(row_clean),
+                        pl.Series(other_clean)
+                    )
+                elif metric == 'cosine':
+                    dist = cosine_distance(
+                        pl.Series(row_clean),
+                        pl.Series(other_clean)
+                    )
+                else:
+                    raise ValueError(f"Unsupported metric: {metric}")
+
+                distances.append(dist)
+
+    return pl.Series(distances)
+
+
+def find_k_nearest(distances: pl.Series, k: int, exclude_idx: Optional[int] = None) -> List[int]:
+    """
+    Find indices of k nearest neighbors.
+
+    Args:
+        distances: Series of distances
+        k: Number of neighbors to find
+        exclude_idx: Index to exclude (the row itself)
+
+    Returns:
+        List of k nearest neighbor indices
+    """
+    # Convert to numpy for easier manipulation
+    dist_array = distances.to_numpy()
+
+    # Get indices sorted by distance
+    sorted_indices = np.argsort(dist_array)
+
+    # Filter out infinite distances and excluded index
+    valid_indices = []
+    for idx in sorted_indices:
+        if not np.isinf(dist_array[idx]):
+            if exclude_idx is None or idx != exclude_idx:
+                valid_indices.append(int(idx))
+                if len(valid_indices) >= k:
+                    break
+
+    return valid_indices
+
+
+def compute_weighted_average(values: pl.Series, distances: pl.Series, weights: str) -> float:
+    """
+    Compute weighted average of neighbor values.
+
+    Args:
+        values: Values from neighbors
+        distances: Distances to neighbors
+        weights: Weighting strategy ('uniform' or 'distance')
+
+    Returns:
+        Weighted average value
+    """
+    if weights == 'uniform':
+        # Simple average
+        return float(values.mean())
+
+    elif weights == 'distance':
+        # Inverse distance weighting
+        values_array = values.to_numpy()
+        distances_array = distances.to_numpy()
+
+        # Avoid division by zero for very small distances
+        distances_array = np.maximum(distances_array, 1e-10)
+
+        # Inverse distance weights
+        inv_distances = 1.0 / distances_array
+
+        # Weighted average
+        weighted_sum = np.sum(values_array * inv_distances)
+        weight_sum = np.sum(inv_distances)
+
+        return float(weighted_sum / weight_sum)
+
+    else:
+        raise ValueError(f"Unsupported weighting strategy: {weights}")
+
+
+def euclidean_distance(row1: pl.Series, row2: pl.Series) -> float:
+    """
+    Calculate Euclidean distance between two rows.
+
+    Args:
+        row1: First row values
+        row2: Second row values
+
+    Returns:
+        Euclidean distance
+    """
+    diff = row1.to_numpy() - row2.to_numpy()
+    return float(math.sqrt(np.sum(diff ** 2)))
+
+
+def manhattan_distance(row1: pl.Series, row2: pl.Series) -> float:
+    """
+    Calculate Manhattan distance between two rows.
+
+    Args:
+        row1: First row values
+        row2: Second row values
+
+    Returns:
+        Manhattan distance
+    """
+    diff = np.abs(row1.to_numpy() - row2.to_numpy())
+    return float(np.sum(diff))
+
+
+def cosine_distance(row1: pl.Series, row2: pl.Series) -> float:
+    """
+    Calculate cosine distance between two rows.
+
+    Args:
+        row1: First row values
+        row2: Second row values
+
+    Returns:
+        Cosine distance
+    """
+    vec1 = row1.to_numpy()
+    vec2 = row2.to_numpy()
+
+    # Calculate dot product and norms
+    dot_product = np.dot(vec1, vec2)
+    norm1 = np.linalg.norm(vec1)
+    norm2 = np.linalg.norm(vec2)
+
+    # Avoid division by zero
+    if norm1 == 0 or norm2 == 0:
+        return 1.0  # Maximum distance
+
+    # Cosine similarity
+    cosine_sim = dot_product / (norm1 * norm2)
+
+    # Cosine distance (1 - similarity)
+    return float(1.0 - cosine_sim)
+
+
+def validate_knn_parameters(df: pl.DataFrame, columns: List[str], k: int) -> bool:
+    """
+    Validate KNN imputation parameters.
+
+    Args:
+        df: DataFrame
+        columns: Columns to impute
+        k: Number of neighbors
+
+    Returns:
+        True if valid
+
+    Raises:
+        ValueError: If parameters are invalid
+    """
+    # Check columns exist
+    for col in columns:
+        if col not in df.columns:
+            raise ValueError(f"Column '{col}' not found in DataFrame")
+
+    # Check at least some non-missing values exist (before checking dtype)
+    # This is important because a column with all nulls has dtype Null
+    for col in columns:
+        non_null_count = df[col].null_count()
+        if non_null_count == len(df):
+            raise ValueError(f"Column '{col}' has all missing values, cannot impute")
+
+    # Check columns are numeric
+    for col in columns:
+        dtype = df[col].dtype
+        if dtype not in [pl.Int8, pl.Int16, pl.Int32, pl.Int64,
+                         pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
+                         pl.Float32, pl.Float64]:
+            raise ValueError(f"Column '{col}' must be numeric, got {dtype}")
+
+    # Check k is positive
+    if k <= 0:
+        raise ValueError(f"k must be positive, got {k}")
+
+    # Check k is less than number of rows
+    if k >= len(df):
+        raise ValueError(f"k ({k}) must be less than number of rows ({len(df)})")
+
+    return True
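
As a usage sketch, the snippet below exercises the knn_impute entry point on a small frame. The data is made up, and the import path `additory.common.knn_imputation` is inferred from the RECORD listing above.

import polars as pl
from additory.common.knn_imputation import knn_impute

# Hypothetical frame with gaps in both columns (nulls become NaN internally)
df = pl.DataFrame({
    "age": [25.0, None, 31.0, 40.0, 29.0, 35.0, None],
    "income": [48_000.0, 52_000.0, None, 75_000.0, 50_000.0, 61_000.0, 58_000.0],
})

# k must be positive and smaller than the row count (see validate_knn_parameters)
imputed = knn_impute(df, columns=["age", "income"], k=3, weights="distance", metric="euclidean")
print(imputed)

Worth noting about the implementation: distances are recomputed against every other row inside a Python loop, so cost grows quadratically with row count. The sketch above is meant to illustrate the API, not a large-scale workload.
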