additory-0.1.0a4-py3-none-any.whl → additory-0.1.1a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. additory/__init__.py +58 -14
  2. additory/common/__init__.py +31 -147
  3. additory/common/column_selector.py +255 -0
  4. additory/common/distributions.py +286 -613
  5. additory/common/extractors.py +313 -0
  6. additory/common/knn_imputation.py +332 -0
  7. additory/common/result.py +380 -0
  8. additory/common/strategy_parser.py +243 -0
  9. additory/common/unit_conversions.py +338 -0
  10. additory/common/validation.py +283 -103
  11. additory/core/__init__.py +34 -22
  12. additory/core/backend.py +258 -0
  13. additory/core/config.py +177 -305
  14. additory/core/logging.py +230 -24
  15. additory/core/memory_manager.py +157 -495
  16. additory/expressions/__init__.py +2 -23
  17. additory/expressions/compiler.py +457 -0
  18. additory/expressions/engine.py +264 -487
  19. additory/expressions/integrity.py +179 -0
  20. additory/expressions/loader.py +263 -0
  21. additory/expressions/parser.py +363 -167
  22. additory/expressions/resolver.py +274 -0
  23. additory/functions/__init__.py +1 -0
  24. additory/functions/analyze/__init__.py +144 -0
  25. additory/functions/analyze/cardinality.py +58 -0
  26. additory/functions/analyze/correlations.py +66 -0
  27. additory/functions/analyze/distributions.py +53 -0
  28. additory/functions/analyze/duplicates.py +49 -0
  29. additory/functions/analyze/features.py +61 -0
  30. additory/functions/analyze/imputation.py +66 -0
  31. additory/functions/analyze/outliers.py +65 -0
  32. additory/functions/analyze/patterns.py +65 -0
  33. additory/functions/analyze/presets.py +72 -0
  34. additory/functions/analyze/quality.py +59 -0
  35. additory/functions/analyze/timeseries.py +53 -0
  36. additory/functions/analyze/types.py +45 -0
  37. additory/functions/expressions/__init__.py +161 -0
  38. additory/functions/snapshot/__init__.py +82 -0
  39. additory/functions/snapshot/filter.py +119 -0
  40. additory/functions/synthetic/__init__.py +113 -0
  41. additory/functions/synthetic/mode_detector.py +47 -0
  42. additory/functions/synthetic/strategies/__init__.py +1 -0
  43. additory/functions/synthetic/strategies/advanced.py +35 -0
  44. additory/functions/synthetic/strategies/augmentative.py +160 -0
  45. additory/functions/synthetic/strategies/generative.py +168 -0
  46. additory/functions/synthetic/strategies/presets.py +116 -0
  47. additory/functions/to/__init__.py +188 -0
  48. additory/functions/to/lookup.py +351 -0
  49. additory/functions/to/merge.py +189 -0
  50. additory/functions/to/sort.py +91 -0
  51. additory/functions/to/summarize.py +170 -0
  52. additory/functions/transform/__init__.py +140 -0
  53. additory/functions/transform/datetime.py +79 -0
  54. additory/functions/transform/extract.py +85 -0
  55. additory/functions/transform/harmonize.py +105 -0
  56. additory/functions/transform/knn.py +62 -0
  57. additory/functions/transform/onehotencoding.py +68 -0
  58. additory/functions/transform/transpose.py +42 -0
  59. additory-0.1.1a1.dist-info/METADATA +83 -0
  60. additory-0.1.1a1.dist-info/RECORD +62 -0
  61. additory/analysis/__init__.py +0 -48
  62. additory/analysis/cardinality.py +0 -126
  63. additory/analysis/correlations.py +0 -124
  64. additory/analysis/distributions.py +0 -376
  65. additory/analysis/quality.py +0 -158
  66. additory/analysis/scan.py +0 -400
  67. additory/common/backend.py +0 -371
  68. additory/common/column_utils.py +0 -191
  69. additory/common/exceptions.py +0 -62
  70. additory/common/lists.py +0 -229
  71. additory/common/patterns.py +0 -240
  72. additory/common/resolver.py +0 -567
  73. additory/common/sample_data.py +0 -182
  74. additory/core/ast_builder.py +0 -165
  75. additory/core/backends/__init__.py +0 -23
  76. additory/core/backends/arrow_bridge.py +0 -483
  77. additory/core/backends/cudf_bridge.py +0 -355
  78. additory/core/column_positioning.py +0 -358
  79. additory/core/compiler_polars.py +0 -166
  80. additory/core/enhanced_cache_manager.py +0 -1119
  81. additory/core/enhanced_matchers.py +0 -473
  82. additory/core/enhanced_version_manager.py +0 -325
  83. additory/core/executor.py +0 -59
  84. additory/core/integrity_manager.py +0 -477
  85. additory/core/loader.py +0 -190
  86. additory/core/namespace_manager.py +0 -657
  87. additory/core/parser.py +0 -176
  88. additory/core/polars_expression_engine.py +0 -601
  89. additory/core/registry.py +0 -177
  90. additory/core/sample_data_manager.py +0 -492
  91. additory/core/user_namespace.py +0 -751
  92. additory/core/validator.py +0 -27
  93. additory/dynamic_api.py +0 -352
  94. additory/expressions/proxy.py +0 -549
  95. additory/expressions/registry.py +0 -313
  96. additory/expressions/samples.py +0 -492
  97. additory/synthetic/__init__.py +0 -13
  98. additory/synthetic/column_name_resolver.py +0 -149
  99. additory/synthetic/deduce.py +0 -259
  100. additory/synthetic/distributions.py +0 -22
  101. additory/synthetic/forecast.py +0 -1132
  102. additory/synthetic/linked_list_parser.py +0 -415
  103. additory/synthetic/namespace_lookup.py +0 -129
  104. additory/synthetic/smote.py +0 -320
  105. additory/synthetic/strategies.py +0 -926
  106. additory/synthetic/synthesizer.py +0 -713
  107. additory/utilities/__init__.py +0 -53
  108. additory/utilities/encoding.py +0 -600
  109. additory/utilities/games.py +0 -300
  110. additory/utilities/keys.py +0 -8
  111. additory/utilities/lookup.py +0 -103
  112. additory/utilities/matchers.py +0 -216
  113. additory/utilities/resolvers.py +0 -286
  114. additory/utilities/settings.py +0 -167
  115. additory/utilities/units.py +0 -749
  116. additory/utilities/validators.py +0 -153
  117. additory-0.1.0a4.dist-info/METADATA +0 -311
  118. additory-0.1.0a4.dist-info/RECORD +0 -72
  119. additory-0.1.0a4.dist-info/licenses/LICENSE +0 -21
  120. {additory-0.1.0a4.dist-info → additory-0.1.1a1.dist-info}/WHEEL +0 -0
  121. {additory-0.1.0a4.dist-info → additory-0.1.1a1.dist-info}/top_level.txt +0 -0
additory/common/extractors.py
@@ -0,0 +1,313 @@
+ """
+ Feature extraction utilities for Additory.
+
+ Provides functions to extract features from various column types
+ (datetime, email, text, URL, phone, etc.).
+ """
+
+ import re
+ from typing import Dict, List, Optional
+ import polars as pl
+
+
+ def extract_datetime_features(series: pl.Series, features: List[str]) -> Dict[str, pl.Series]:
+     """
+     Extract datetime features from a datetime column.
+
+     Args:
+         series: Polars Series with datetime values
+         features: List of features to extract
+
+     Returns:
+         Dictionary mapping feature name to Series
+
+     Supported Features:
+         - 'year', 'month', 'day', 'hour', 'minute', 'second'
+         - 'day_of_week', 'day_of_year', 'week', 'quarter'
+         - 'time_of_day' ('morning', 'afternoon', 'evening', 'night')
+         - 'is_weekend', 'is_business_day'
+
+     Example:
+         features = extract_datetime_features(df['timestamp'], ['hour', 'day_of_week'])
+     """
+     result = {}
+
+     for feature in features:
+         if feature == 'year':
+             result[feature] = series.dt.year()
+         elif feature == 'month':
+             result[feature] = series.dt.month()
+         elif feature == 'day':
+             result[feature] = series.dt.day()
+         elif feature == 'hour':
+             result[feature] = series.dt.hour()
+         elif feature == 'minute':
+             result[feature] = series.dt.minute()
+         elif feature == 'second':
+             result[feature] = series.dt.second()
+         elif feature == 'day_of_week':
+             result[feature] = series.dt.weekday()
+         elif feature == 'day_of_year':
+             result[feature] = series.dt.ordinal_day()
+         elif feature == 'week':
+             result[feature] = series.dt.week()
+         elif feature == 'quarter':
+             result[feature] = series.dt.quarter()
+         elif feature == 'time_of_day':
+             hour = series.dt.hour()
+             time_of_day = pl.when(hour < 6).then(pl.lit('night')) \
+                 .when(hour < 12).then(pl.lit('morning')) \
+                 .when(hour < 18).then(pl.lit('afternoon')) \
+                 .otherwise(pl.lit('evening'))
+             # Evaluate the expression to get a Series
+             result[feature] = pl.select(time_of_day.alias('time_of_day')).to_series()
+         elif feature == 'is_weekend':
+             weekday = series.dt.weekday()
+             result[feature] = (weekday >= 6)
+         elif feature == 'is_business_day':
+             weekday = series.dt.weekday()
+             # Polars weekday() is 1 (Monday) through 7 (Sunday), so business
+             # days are 1-5; `weekday < 5` would incorrectly drop Friday.
+             result[feature] = (weekday <= 5)
+         else:
+             raise ValueError(f"Unsupported datetime feature: {feature}")
+
+     return result
+
+
+ def extract_email_features(series: pl.Series, features: List[str]) -> Dict[str, pl.Series]:
+     """
+     Extract features from email addresses.
+
+     Args:
+         series: Polars Series with email strings
+         features: List of features to extract
+
+     Returns:
+         Dictionary mapping feature name to Series
+
+     Supported Features:
+         - 'domain' - Email domain (e.g., 'gmail.com')
+         - 'username' - Username part (before @)
+         - 'tld' - Top-level domain (e.g., 'com')
+         - 'is_free_email' - Boolean (gmail, yahoo, etc.)
+         - 'is_corporate' - Boolean (not free email)
+
+     Example:
+         features = extract_email_features(df['email'], ['domain', 'is_free_email'])
+     """
+     result = {}
+
+     # Common free email providers
+     free_providers = {'gmail.com', 'yahoo.com', 'hotmail.com', 'outlook.com',
+                       'aol.com', 'icloud.com', 'mail.com', 'protonmail.com'}
+
+     for feature in features:
+         if feature == 'domain':
+             # Extract domain (part after @)
+             result[feature] = series.str.extract(r'@(.+)$', 1)
+         elif feature == 'username':
+             # Extract username (part before @)
+             result[feature] = series.str.extract(r'^([^@]+)@', 1)
+         elif feature == 'tld':
+             # Extract top-level domain (last part after .)
+             result[feature] = series.str.extract(r'\.([^.]+)$', 1)
+         elif feature == 'is_free_email':
+             # Check if domain is in free providers list
+             domain = series.str.extract(r'@(.+)$', 1)
+             result[feature] = domain.is_in(list(free_providers))
+         elif feature == 'is_corporate':
+             # Opposite of is_free_email
+             domain = series.str.extract(r'@(.+)$', 1)
+             result[feature] = ~domain.is_in(list(free_providers))
+         else:
+             raise ValueError(f"Unsupported email feature: {feature}")
+
+     return result
+
+
+ def extract_text_features(series: pl.Series, features: List[str]) -> Dict[str, pl.Series]:
+     """
+     Extract features from text columns.
+
+     Args:
+         series: Polars Series with text strings
+         features: List of features to extract
+
+     Returns:
+         Dictionary mapping feature name to Series
+
+     Supported Features:
+         - 'length' - Character count
+         - 'word_count' - Number of words
+         - 'sentence_count' - Number of sentences
+         - 'has_numbers' - Boolean
+         - 'has_special_chars' - Boolean
+         - 'is_uppercase' - Boolean
+         - 'is_lowercase' - Boolean
+
+     Example:
+         features = extract_text_features(df['description'], ['length', 'word_count'])
+     """
+     result = {}
+
+     for feature in features:
+         if feature == 'length':
+             result[feature] = series.str.len_chars()
+         elif feature == 'word_count':
+             # Count words (split by whitespace)
+             result[feature] = series.str.split(' ').list.len()
+         elif feature == 'sentence_count':
+             # Count sentences (split by . ! ?)
+             result[feature] = series.str.count_matches(r'[.!?]')
+         elif feature == 'has_numbers':
+             result[feature] = series.str.contains(r'\d')
+         elif feature == 'has_special_chars':
+             result[feature] = series.str.contains(r'[^a-zA-Z0-9\s]')
+         elif feature == 'is_uppercase':
+             result[feature] = series.str.to_uppercase() == series
+         elif feature == 'is_lowercase':
+             result[feature] = series.str.to_lowercase() == series
+         else:
+             raise ValueError(f"Unsupported text feature: {feature}")
+
+     return result
+
+
+ def split_column(series: pl.Series, delimiter: str, max_splits: Optional[int] = None) -> List[pl.Series]:
+     """
+     Split column by delimiter into multiple columns.
+
+     Args:
+         series: Polars Series to split
+         delimiter: Delimiter string
+         max_splits: Maximum number of splits (None = unlimited)
+
+     Returns:
+         List of Series (one per split part)
+
+     Example:
+         # Split 'red,green,blue' into 3 columns
+         parts = split_column(df['tags'], delimiter=',', max_splits=3)
+     """
+     # Split the series
+     split_series = series.str.split(delimiter)
+
+     # Determine number of parts
+     if max_splits is None:
+         # Find maximum number of parts
+         max_len = split_series.list.len().max()
+         if max_len is None:
+             # Empty series
+             return []
+         max_parts = int(max_len)
+     else:
+         max_parts = max_splits
+
+     # Extract each part as a separate series
+     result = []
+     for i in range(max_parts):
+         try:
+             part = split_series.list.get(i, null_on_oob=True)
+             result.append(part)
+         except Exception:
+             # If index is out of bounds, append None series
+             result.append(pl.Series([None] * len(series)))
+
+     return result
+
+
+ def extract_url_features(series: pl.Series, features: List[str]) -> Dict[str, pl.Series]:
+     """
+     Extract features from URLs.
+
+     Args:
+         series: Polars Series with URL strings
+         features: List of features to extract
+
+     Returns:
+         Dictionary mapping feature name to Series
+
+     Supported Features:
+         - 'protocol' - http, https, ftp, etc.
+         - 'domain' - Domain name
+         - 'path' - URL path
+         - 'is_secure' - Boolean (https)
+
+     Example:
+         features = extract_url_features(df['url'], ['protocol', 'domain'])
+     """
+     result = {}
+
+     for feature in features:
+         if feature == 'protocol':
+             # Extract protocol (e.g., http, https)
+             result[feature] = series.str.extract(r'^([a-z]+)://', 1)
+         elif feature == 'domain':
+             # Extract domain (between :// and /)
+             result[feature] = series.str.extract(r'://([^/]+)', 1)
+         elif feature == 'path':
+             # Extract path (after domain)
+             result[feature] = series.str.extract(r'://[^/]+(/[^?#]*)', 1)
+         elif feature == 'is_secure':
+             # Check if protocol is https
+             protocol = series.str.extract(r'^([a-z]+)://', 1)
+             result[feature] = (protocol == 'https')
+         else:
+             raise ValueError(f"Unsupported URL feature: {feature}")
+
+     return result
+
+
+ def extract_phone_features(series: pl.Series, features: List[str]) -> Dict[str, pl.Series]:
+     """
+     Extract features from phone numbers.
+
+     Args:
+         series: Polars Series with phone number strings
+         features: List of features to extract
+
+     Returns:
+         Dictionary mapping feature name to Series
+
+     Supported Features:
+         - 'country_code' - Country code
+         - 'area_code' - Area code
+         - 'is_valid' - Boolean (basic validation)
+
+     Example:
+         features = extract_phone_features(df['phone'], ['country_code', 'area_code'])
+     """
+     result = {}
+
+     for feature in features:
+         if feature == 'country_code':
+             # Extract country code (e.g., +1, +44)
+             result[feature] = series.str.extract(r'^\+?(\d{1,3})', 1)
+         elif feature == 'area_code':
+             # Extract area code (3 digits after country code)
+             result[feature] = series.str.extract(r'\(?(\d{3})\)?', 1)
+         elif feature == 'is_valid':
+             # Basic validation: has at least 10 digits
+             digits_only = series.str.replace_all(r'\D', '')
+             result[feature] = digits_only.str.len_chars() >= 10
+         else:
+             raise ValueError(f"Unsupported phone feature: {feature}")
+
+     return result
+
+
+ def parse_pattern(series: pl.Series, pattern: str) -> pl.Series:
+     """
+     Extract pattern from strings using regex.
+
+     Args:
+         series: Polars Series with strings
+         pattern: Regex pattern to extract
+
+     Returns:
+         Series with extracted values
+
+     Example:
+         # Extract numbers from strings
+         numbers = parse_pattern(df['text'], r'\d+')
+     """
+     return series.str.extract(pattern, 0)
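
For readers evaluating the new extractors module, here is a minimal usage sketch. It assumes the helpers are imported directly from additory.common.extractors (the path shown in the file list above); the package may also re-export them under a shorter namespace, and the column names below are illustrative only.

    import polars as pl
    from additory.common.extractors import extract_email_features, extract_text_features

    df = pl.DataFrame({
        "email": ["alice@gmail.com", "bob@acme.io"],
        "bio": ["Data engineer. Loves Polars!", "ML researcher"],
    })

    # Each extractor returns a dict of feature name -> Series,
    # which can be attached to the frame with with_columns().
    email_feats = extract_email_features(df["email"], ["domain", "is_free_email"])
    text_feats = extract_text_features(df["bio"], ["length", "word_count"])

    df = df.with_columns(
        [s.alias(f"email_{name}") for name, s in email_feats.items()]
        + [s.alias(f"bio_{name}") for name, s in text_feats.items()]
    )
    print(df)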
additory/common/knn_imputation.py
@@ -0,0 +1,332 @@
+ """
+ K-Nearest Neighbors imputation utilities for Additory.
+
+ Provides KNN-based imputation for missing values in numeric columns.
+ """
+
+ import math
+ from typing import List, Optional
+ import polars as pl
+ import numpy as np
+
+
+ def knn_impute(df: pl.DataFrame, columns: List[str], k: int = 5,
+                weights: str = 'distance', metric: str = 'euclidean') -> pl.DataFrame:
+     """
+     Impute missing values using K-Nearest Neighbors.
+
+     Args:
+         df: DataFrame with missing values
+         columns: Columns to impute
+         k: Number of neighbors (default: 5)
+         weights: Weighting strategy ('uniform' or 'distance')
+         metric: Distance metric ('euclidean', 'manhattan', 'cosine')
+
+     Returns:
+         DataFrame with imputed values
+
+     Example:
+         imputed_df = knn_impute(df, columns=['age', 'income'], k=5, weights='distance')
+     """
+     # Validate parameters
+     validate_knn_parameters(df, columns, k)
+
+     # Create a copy to avoid modifying original
+     result_df = df.clone()
+
+     # Get all numeric columns for distance calculation
+     all_numeric_cols = [col for col in df.columns
+                         if df[col].dtype in [pl.Int8, pl.Int16, pl.Int32, pl.Int64,
+                                              pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
+                                              pl.Float32, pl.Float64]]
+
+     # Convert to numpy for easier manipulation
+     data = df.select(columns).to_numpy().copy()  # Make a copy we can modify
+
+     # For each row with missing values
+     for row_idx in range(len(df)):
+         row = data[row_idx]
+
+         # Check if this row has any missing values
+         if not np.any(np.isnan(row)):
+             continue
+
+         # Check if this row has at least some non-missing values in ANY numeric column
+         # (needed for distance calculation)
+         row_all_cols = df.select(all_numeric_cols).row(row_idx)
+         has_any_non_missing = any(val is not None and not (isinstance(val, float) and np.isnan(val))
+                                   for val in row_all_cols)
+         if not has_any_non_missing:
+             # All numeric values missing, skip this row
+             continue
+
+         # Calculate distances to all other rows using ALL numeric columns
+         distances = calculate_distances(df, row_idx, all_numeric_cols, metric)
+
+         # Find k nearest neighbors
+         neighbor_indices = find_k_nearest(distances, k, exclude_idx=row_idx)
+
+         # Impute each missing value
+         for col_idx, col_name in enumerate(columns):
+             if np.isnan(row[col_idx]):
+                 # Get values from neighbors
+                 neighbor_values = []
+                 neighbor_distances = []
+
+                 for neighbor_idx in neighbor_indices:
+                     neighbor_value = data[neighbor_idx, col_idx]
+                     if not np.isnan(neighbor_value):
+                         neighbor_values.append(neighbor_value)
+                         neighbor_distances.append(distances[neighbor_idx])
+
+                 # If we have neighbor values, compute weighted average
+                 if neighbor_values:
+                     neighbor_values_series = pl.Series(neighbor_values)
+                     neighbor_distances_series = pl.Series(neighbor_distances)
+
+                     imputed_value = compute_weighted_average(
+                         neighbor_values_series,
+                         neighbor_distances_series,
+                         weights
+                     )
+
+                     # Update the data array
+                     data[row_idx, col_idx] = imputed_value
+
+     # Replace the columns in the result DataFrame with imputed data
+     for col_idx, col_name in enumerate(columns):
+         result_df = result_df.with_columns(
+             pl.Series(col_name, data[:, col_idx])
+         )
+
+     return result_df
+
+
+ def calculate_distances(df: pl.DataFrame, row_idx: int, columns: List[str],
+                         metric: str) -> pl.Series:
+     """
+     Calculate distances from a row to all other rows.
+
+     Args:
+         df: DataFrame
+         row_idx: Index of row to calculate distances from
+         columns: Columns to use for distance calculation
+         metric: Distance metric
+
+     Returns:
+         Series of distances
+     """
+     data = df.select(columns).to_numpy()
+     row = data[row_idx]
+
+     distances = []
+     for i in range(len(df)):
+         if i == row_idx:
+             distances.append(float('inf'))  # Exclude self
+         else:
+             other_row = data[i]
+
+             # Only use non-missing values for distance calculation
+             mask = ~(np.isnan(row) | np.isnan(other_row))
+
+             if not np.any(mask):
+                 # No common non-missing values
+                 distances.append(float('inf'))
+             else:
+                 row_clean = row[mask]
+                 other_clean = other_row[mask]
+
+                 if metric == 'euclidean':
+                     dist = euclidean_distance(
+                         pl.Series(row_clean),
+                         pl.Series(other_clean)
+                     )
+                 elif metric == 'manhattan':
+                     dist = manhattan_distance(
+                         pl.Series(row_clean),
+                         pl.Series(other_clean)
+                     )
+                 elif metric == 'cosine':
+                     dist = cosine_distance(
+                         pl.Series(row_clean),
+                         pl.Series(other_clean)
+                     )
+                 else:
+                     raise ValueError(f"Unsupported metric: {metric}")
+
+                 distances.append(dist)
+
+     return pl.Series(distances)
+
+
+ def find_k_nearest(distances: pl.Series, k: int, exclude_idx: Optional[int] = None) -> List[int]:
+     """
+     Find indices of k nearest neighbors.
+
+     Args:
+         distances: Series of distances
+         k: Number of neighbors to find
+         exclude_idx: Index to exclude (the row itself)
+
+     Returns:
+         List of k nearest neighbor indices
+     """
+     # Convert to numpy for easier manipulation
+     dist_array = distances.to_numpy()
+
+     # Get indices sorted by distance
+     sorted_indices = np.argsort(dist_array)
+
+     # Filter out infinite distances and excluded index
+     valid_indices = []
+     for idx in sorted_indices:
+         if not np.isinf(dist_array[idx]):
+             if exclude_idx is None or idx != exclude_idx:
+                 valid_indices.append(int(idx))
+         if len(valid_indices) >= k:
+             break
+
+     return valid_indices
+
+
+ def compute_weighted_average(values: pl.Series, distances: pl.Series, weights: str) -> float:
+     """
+     Compute weighted average of neighbor values.
+
+     Args:
+         values: Values from neighbors
+         distances: Distances to neighbors
+         weights: Weighting strategy ('uniform' or 'distance')
+
+     Returns:
+         Weighted average value
+     """
+     if weights == 'uniform':
+         # Simple average
+         return float(values.mean())
+
+     elif weights == 'distance':
+         # Inverse distance weighting
+         values_array = values.to_numpy()
+         distances_array = distances.to_numpy()
+
+         # Avoid division by zero for very small distances
+         distances_array = np.maximum(distances_array, 1e-10)
+
+         # Inverse distance weights
+         inv_distances = 1.0 / distances_array
+
+         # Weighted average
+         weighted_sum = np.sum(values_array * inv_distances)
+         weight_sum = np.sum(inv_distances)
+
+         return float(weighted_sum / weight_sum)
+
+     else:
+         raise ValueError(f"Unsupported weighting strategy: {weights}")
+
+
+ def euclidean_distance(row1: pl.Series, row2: pl.Series) -> float:
+     """
+     Calculate Euclidean distance between two rows.
+
+     Args:
+         row1: First row values
+         row2: Second row values
+
+     Returns:
+         Euclidean distance
+     """
+     diff = row1.to_numpy() - row2.to_numpy()
+     return float(math.sqrt(np.sum(diff ** 2)))
+
+
+ def manhattan_distance(row1: pl.Series, row2: pl.Series) -> float:
+     """
+     Calculate Manhattan distance between two rows.
+
+     Args:
+         row1: First row values
+         row2: Second row values
+
+     Returns:
+         Manhattan distance
+     """
+     diff = np.abs(row1.to_numpy() - row2.to_numpy())
+     return float(np.sum(diff))
+
+
+ def cosine_distance(row1: pl.Series, row2: pl.Series) -> float:
+     """
+     Calculate Cosine distance between two rows.
+
+     Args:
+         row1: First row values
+         row2: Second row values
+
+     Returns:
+         Cosine distance
+     """
+     vec1 = row1.to_numpy()
+     vec2 = row2.to_numpy()
+
+     # Calculate dot product and norms
+     dot_product = np.dot(vec1, vec2)
+     norm1 = np.linalg.norm(vec1)
+     norm2 = np.linalg.norm(vec2)
+
+     # Avoid division by zero
+     if norm1 == 0 or norm2 == 0:
+         return 1.0  # Maximum distance
+
+     # Cosine similarity
+     cosine_sim = dot_product / (norm1 * norm2)
+
+     # Cosine distance (1 - similarity)
+     return float(1.0 - cosine_sim)
+
+
+ def validate_knn_parameters(df: pl.DataFrame, columns: List[str], k: int) -> bool:
+     """
+     Validate KNN imputation parameters.
+
+     Args:
+         df: DataFrame
+         columns: Columns to impute
+         k: Number of neighbors
+
+     Returns:
+         True if valid
+
+     Raises:
+         ValueError: If parameters are invalid
+     """
+     # Check columns exist
+     for col in columns:
+         if col not in df.columns:
+             raise ValueError(f"Column '{col}' not found in DataFrame")
+
+     # Check at least some non-missing values exist (before checking dtype)
+     # This is important because a column with all nulls has dtype Null
+     for col in columns:
+         null_count = df[col].null_count()
+         if null_count == len(df):
+             raise ValueError(f"Column '{col}' has all missing values, cannot impute")
+
+     # Check columns are numeric
+     for col in columns:
+         dtype = df[col].dtype
+         if dtype not in [pl.Int8, pl.Int16, pl.Int32, pl.Int64,
+                          pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
+                          pl.Float32, pl.Float64]:
+             raise ValueError(f"Column '{col}' must be numeric, got {dtype}")
+
+     # Check k is positive
+     if k <= 0:
+         raise ValueError(f"k must be positive, got {k}")
+
+     # Check k is less than number of rows
+     if k >= len(df):
+         raise ValueError(f"k ({k}) must be less than number of rows ({len(df)})")
+
+     return True
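
A usage sketch for the new KNN imputer follows. With weights='distance', compute_weighted_average applies inverse-distance weighting: each neighbor value x_i contributes weight 1/d_i, and the imputed value is (sum of x_i/d_i) / (sum of 1/d_i). The import path is inferred from the file list above, and the data is made up; note that calculate_distances is recomputed per row with missing values, so the overall cost is roughly O(n^2) in rows and best suited to small and medium frames.

    import polars as pl
    from additory.common.knn_imputation import knn_impute

    df = pl.DataFrame({
        "age":    [25.0, None, 31.0, 29.0, None, 40.0],
        "income": [48_000.0, 52_000.0, None, 61_000.0, 55_000.0, 75_000.0],
    })

    # k must be positive and smaller than the row count
    # (enforced by validate_knn_parameters).
    imputed = knn_impute(df, columns=["age", "income"], k=3, weights="distance")
    print(imputed)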