kumoai 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl → 2.15.0.dev202601141731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
  11. kumoai/experimental/rfm/backend/snow/table.py +146 -70
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/mapper.py +69 -0
  18. kumoai/experimental/rfm/base/sampler.py +28 -18
  19. kumoai/experimental/rfm/base/source.py +1 -1
  20. kumoai/experimental/rfm/base/sql_sampler.py +320 -19
  21. kumoai/experimental/rfm/base/table.py +256 -109
  22. kumoai/experimental/rfm/base/utils.py +36 -0
  23. kumoai/experimental/rfm/graph.py +115 -107
  24. kumoai/experimental/rfm/infer/dtype.py +7 -2
  25. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  26. kumoai/experimental/rfm/infer/time_col.py +4 -2
  27. kumoai/experimental/rfm/relbench.py +76 -0
  28. kumoai/experimental/rfm/rfm.py +540 -306
  29. kumoai/experimental/rfm/task_table.py +292 -0
  30. kumoai/pquery/training_table.py +16 -2
  31. kumoai/testing/snow.py +3 -3
  32. kumoai/trainer/distilled_trainer.py +175 -0
  33. kumoai/utils/display.py +87 -0
  34. kumoai/utils/progress_logger.py +15 -2
  35. kumoai/utils/sql.py +2 -2
  36. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +2 -2
  37. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +40 -35
  38. kumoai/experimental/rfm/base/column_expression.py +0 -50
  39. kumoai/experimental/rfm/base/sql_table.py +0 -229
  40. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
  41. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
  42. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,21 @@
 import json
+import math
 from collections.abc import Iterator
 from contextlib import contextmanager
+from typing import TYPE_CHECKING, cast

 import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery

-from kumoai.experimental.rfm.backend.snow import Connection
-from kumoai.experimental.rfm.base import SQLSampler
+from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
-from kumoai.utils import quote_ident
+from kumoai.utils import ProgressLogger, quote_ident
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph


 @contextmanager
@@ -22,30 +27,51 @@ def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:


 class SnowSampler(SQLSampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        for table in graph.tables.values():
+            assert isinstance(table, SnowTable)
+            self._connection = table._connection
+
+        self._num_rows_dict: dict[str, int] = {
+            table.name: cast(int, table._num_rows)
+            for table in graph.tables.values()
+        }
+
+    @property
+    def num_rows_dict(self) -> dict[str, int]:
+        return self._num_rows_dict
+
     def _get_min_max_time_dict(
         self,
         table_names: list[str],
     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
         selects: list[str] = []
         for table_name in table_names:
-            time_column = self.time_column_dict[table_name]
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
+            ident = quote_ident(table_name, char="'")
             select = (f"SELECT\n"
-                      f" ? as table_name,\n"
-                      f" MIN({quote_ident(time_column)}) as min_date,\n"
-                      f" MAX({quote_ident(time_column)}) as max_date\n"
-                      f"FROM {self.fqn_dict[table_name]}")
+                      f" {ident} as table_name,\n"
+                      f" MIN({column_ref}) as min_date,\n"
+                      f" MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
             selects.append(select)
         sql = "\nUNION ALL\n".join(selects)

         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
-        with paramstyle(self._connection), self._connection.cursor() as cursor:
-            cursor.execute(sql, table_names)
-            rows = cursor.fetchall()
-            for table_name, _min, _max in rows:
-                out_dict[table_name] = (
-                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
-                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
-                )
+        with self._connection.cursor() as cursor:
+            cursor.execute(sql)
+            for table_name, _min, _max in cursor.fetchall():
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
                )

         return out_dict

@@ -59,17 +85,27 @@ class SnowSampler(SQLSampler):
         # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
         num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.

+        source_table = self.source_table_dict[table_name]
         filters: list[str] = []
-        primary_key = self.primary_key_dict[table_name]
-        if self.source_table_dict[table_name][primary_key].is_nullable:
-            filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-        time_column = self.time_column_dict.get(table_name)
-        if (time_column is not None and
-                self.source_table_dict[table_name][time_column].is_nullable):
-            filters.append(f" {quote_ident(time_column)} IS NOT NULL")

-        sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
-               f"FROM {self.fqn_dict[table_name]}\n"
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
+
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}\n"
               f"SAMPLE ROW ({num_rows} ROWS)")
         if len(filters) > 0:
             sql += f"\nWHERE{' AND'.join(filters)}"
@@ -79,7 +115,11 @@ class SnowSampler(SQLSampler):
             cursor.execute(sql)
             table = cursor.fetch_arrow_all()

-        return self._sanitize(table_name, table)
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )

     def _sample_target(
         self,
@@ -114,11 +154,11 @@ class SnowSampler(SQLSampler):
             query.entity_table: np.arange(len(entity_df)),
         }
         for edge_type, (min_offset, max_offset) in time_offset_dict.items():
-            table_name, fkey, _ = edge_type
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-                fkey=fkey,
-                pkey=entity_df[self.primary_key_dict[query.entity_table]],
+                foreign_key=foreign_key,
+                index=entity_df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=min_offset,
                 max_offset=max_offset,
@@ -149,104 +189,219 @@ class SnowSampler(SQLSampler):
     def _by_pkey(
         self,
         table_name: str,
-        pkey: pd.Series,
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]

-        pkey_name = self.primary_key_dict[table_name]
-        source_table = self.source_table_dict[table_name]
-
-        payload = json.dumps(list(pkey))
+        payload = json.dumps(list(index))

         sql = ("WITH TMP as (\n"
                " SELECT\n"
-               " f.index as BATCH,\n")
-        if source_table[pkey_name].dtype.is_int():
-            sql += " f.value::NUMBER as ID\n"
-        elif source_table[pkey_name].dtype.is_float():
-            sql += " f.value::FLOAT as ID\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][key].is_int():
+            sql += " f.value::NUMBER as __KUMO_ID__\n"
+        elif self.table_dtype_dict[table_name][key].is_float():
+            sql += " f.value::FLOAT as __KUMO_ID__\n"
         else:
-            sql += " f.value::VARCHAR as ID\n"
+            sql += " f.value::VARCHAR as __KUMO_ID__\n"
         sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                 f")\n"
-                f"SELECT TMP.BATCH as __BATCH__, "
-                f"{', '.join('ENT.' + quote_ident(col) for col in columns)}\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
                 f"FROM TMP\n"
-                f"JOIN {self.fqn_dict[table_name]} ENT\n"
-                f" ON ENT.{quote_ident(pkey_name)} = TMP.ID")
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__")

         with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
             table = cursor.fetch_arrow_all()

         # Remove any duplicated primary keys in post-processing:
-        tmp = table.append_column('__TMP__', pa.array(range(len(table))))
-        gb = tmp.group_by('__BATCH__').aggregate([('__TMP__', 'min')])
-        table = table.take(gb['__TMP___min'])
+        tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+        gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+        table = table.take(gb['__KUMO_ID___min'])
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)

-        batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
-        table = table.remove_column(table.schema.get_field_index('__BATCH__'))
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch

-        return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        end_time: pd.Series | None = None
+        start_time: pd.Series | None = None
+        if time_column is not None and anchor_time is not None:
+            # In order to avoid a full table scan, we limit foreign key
+            # sampling to a certain time range, approximated by the number of
+            # rows, timestamp ranges and `num_neighbors` value.
+            # Downstream, this helps Snowflake to apply partition pruning:
+            dst_table_name = [
+                dst_table
+                for key, dst_table in self.foreign_key_dict[table_name]
+                if key == foreign_key
+            ][0]
+            num_facts = self.num_rows_dict[table_name]
+            num_entities = self.num_rows_dict[dst_table_name]
+            min_time = self.get_min_time([table_name])
+            max_time = self.get_max_time([table_name])
+            freq = num_facts / num_entities
+            freq = freq / max((max_time - min_time).total_seconds(), 1)
+            offset = pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))
+
+            end_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            start_time = anchor_time - offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, end_time, start_time)))
+        else:
+            payload = json.dumps(list(zip(index)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__"
+        else:
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+        if end_time is not None and start_time is not None:
+            sql += (",\n"
+                    " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__,\n"
+                    " f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__")
+        sql += (f"\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n")
+        if end_time is not None and start_time is not None:
+            assert time_column is not None
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += (f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n"
+                    f" AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+                    f"WHERE {time_ref} <= '{end_time.max()}'\n"
+                    f" AND {time_ref} > '{start_time.min()}'\n")
+        sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                " PARTITION BY TMP.__KUMO_BATCH__\n")
+        if time_column is not None:
+            sql += f" ORDER BY {time_ref} DESC\n"
+        else:
+            sql += f" ORDER BY {key_ref}\n"
+        sql += f") <= {num_neighbors}"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch

     # Helper Methods ##########################################################

     def _by_time(
         self,
         table_name: str,
-        fkey: str,
-        pkey: pd.Series,
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]

         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        start_time: pd.Series | None = None
         if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            payload = json.dumps(list(zip(pkey, end_time, start_time)))
+            payload = json.dumps(list(zip(index, end_time, start_time)))
         else:
-            payload = json.dumps(list(zip(pkey, end_time)))
-
-        # Based on benchmarking, JSON payload is the fastest way to query by
-        # custom indices (compared to large `IN` clauses or temporary tables):
-        source_table = self.source_table_dict[table_name]
-        time_column = self.time_column_dict[table_name]
+            payload = json.dumps(list(zip(index, end_time)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
         sql = ("WITH TMP as (\n"
                " SELECT\n"
-               " f.index as BATCH,\n")
-        if source_table[fkey].dtype.is_int():
-            sql += " f.value[0]::NUMBER as ID,\n"
-        elif source_table[fkey].dtype.is_float():
-            sql += " f.value[0]::FLOAT as ID,\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
         else:
-            sql += " f.value[0]::VARCHAR as ID,\n"
-        sql += " f.value[1]::TIMESTAMP_NTZ as END_TIME"
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+        sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
         if min_offset is not None:
-            sql += ",\n f.value[2]::TIMESTAMP_NTZ as START_TIME"
+            sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
         sql += (f"\n"
                 f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                 f")\n"
-                f"SELECT TMP.BATCH as __BATCH__, "
-                f"{', '.join('FACT.' + quote_ident(col) for col in columns)}\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
                 f"FROM TMP\n"
-                f"JOIN {self.fqn_dict[table_name]} FACT\n"
-                f" ON FACT.{quote_ident(fkey)} = TMP.ID\n"
-                f" AND FACT.{quote_ident(time_column)} <= TMP.END_TIME")
-        if min_offset is not None:
-            sql += f"\n AND FACT.{quote_ident(time_column)} > TMP.START_TIME"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n")
+        if start_time is not None:
+            sql += f"AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+        # Add global time bounds to enable partition pruning:
+        sql += f"WHERE {time_ref} <= '{end_time.max()}'"
+        if start_time is not None:
+            sql += f"\nAND {time_ref} > '{start_time.min()}'"

         with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
             table = cursor.fetch_arrow_all()

-        batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
-        table = table.remove_column(table.schema.get_field_index('__BATCH__'))
-
-        return self._sanitize(table_name, table), batch
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)

-    def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-        return table.to_pandas(types_mapper=pd.ArrowDtype)
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
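
Note: the new `_by_fkey` method above sizes a per-anchor lookback window from table statistics so that Snowflake can prune partitions instead of scanning the full fact table. The sketch below isolates that arithmetic as a standalone helper; the function name `fkey_time_window` and the example row counts and date range are hypothetical and not part of the package.

import math

import pandas as pd


def fkey_time_window(
    num_facts: int,          # rows in the fact table
    num_entities: int,       # rows in the referenced entity table
    min_time: pd.Timestamp,  # earliest fact timestamp
    max_time: pd.Timestamp,  # latest fact timestamp
    num_neighbors: int,      # neighbors to sample per entity
) -> pd.Timedelta:
    """How far back to look before each anchor time so that, on average,
    roughly 5x `num_neighbors` facts per entity fall inside the window."""
    # Average facts per entity per second over the observed time span:
    freq = num_facts / num_entities
    freq = freq / max((max_time - min_time).total_seconds(), 1)
    # Window sized to cover ~5x the requested number of neighbors:
    return pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))


# Example: 100M facts across 1M entities over one year, 8 neighbors each:
window = fkey_time_window(
    num_facts=100_000_000,
    num_entities=1_000_000,
    min_time=pd.Timestamp("2024-01-01"),
    max_time=pd.Timestamp("2025-01-01"),
    num_neighbors=8,
)
print(window)  # ~146 days; older facts are excluded per anchor row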