kumoai 2.13.0.dev202512021731__cp310-cp310-win_amd64.whl → 2.13.0.dev202512040252__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,25 @@
  from abc import ABC, abstractmethod
- from typing import Dict, List, Optional, Sequence, Tuple
+ from collections import defaultdict
+ from functools import cached_property
+ from typing import Dict, List, Optional, Sequence, Set

  import pandas as pd
  from kumoapi.source_table import UnavailableSourceTable
  from kumoapi.table import Column as ColumnDefinition
  from kumoapi.table import TableDefinition
- from kumoapi.typing import Dtype, Stype
+ from kumoapi.typing import Stype
  from typing_extensions import Self

  from kumoai import in_notebook
- from kumoai.experimental.rfm.base import Column
+ from kumoai.experimental.rfm.base import Column, SourceColumn, SourceForeignKey
+ from kumoai.experimental.rfm.infer import (
+     contains_categorical,
+     contains_id,
+     contains_multicategorical,
+     contains_timestamp,
+     infer_primary_key,
+     infer_time_column,
+ )


  class Table(ABC):
@@ -39,8 +49,30 @@ class Table(ABC):
          self._time_column: Optional[str] = None
          self._end_time_column: Optional[str] = None

+         if len(self._source_column_dict) == 0:
+             raise ValueError(f"Table '{name}' does not hold any column with "
+                              f"a supported data type")
+
+         primary_keys = [
+             column.name for column in self._source_column_dict.values()
+             if column.is_primary_key
+         ]
+         if len(primary_keys) == 1:  # NOTE No composite keys yet.
+             if primary_key is not None and primary_key != primary_keys[0]:
+                 raise ValueError(f"Found duplicate primary key "
+                                  f"definition '{primary_key}' and "
+                                  f"'{primary_keys[0]}' in table '{name}'")
+             primary_key = primary_keys[0]
+
+         unique_keys = [
+             column.name for column in self._source_column_dict.values()
+             if column.is_unique_key
+         ]
+         if primary_key is None and len(unique_keys) == 1:
+             primary_key = unique_keys[0]
+
          self._columns: Dict[str, Column] = {}
-         for column_name in columns or []:
+         for column_name in columns or list(self._source_column_dict.keys()):
              self.add_column(column_name)

          if primary_key is not None:
@@ -104,12 +136,12 @@ class Table(ABC):
              raise KeyError(f"Column '{name}' already exists in table "
                             f"'{self.name}'")

-         if not self._has_source_column(name):
+         if name not in self._source_column_dict:
              raise KeyError(f"Column '{name}' does not exist in the underlying "
                             f"source table")

          try:
-             dtype = self._get_source_dtype(name)
+             dtype = self._source_column_dict[name].dtype
          except Exception as e:
              raise RuntimeError(f"Could not obtain data type for column "
                                 f"'{name}' in table '{self.name}'. Change "
@@ -117,7 +149,17 @@ class Table(ABC):
                                 f"table or remove it from the table.") from e

          try:
-             stype = self._get_source_stype(name, dtype)
+             ser = self._sample_df[name]
+             if contains_id(ser, name, dtype):
+                 stype = Stype.ID
+             elif contains_timestamp(ser, name, dtype):
+                 stype = Stype.timestamp
+             elif contains_multicategorical(ser, name, dtype):
+                 stype = Stype.multicategorical
+             elif contains_categorical(ser, name, dtype):
+                 stype = Stype.categorical
+             else:
+                 stype = dtype.default_stype
          except Exception as e:
              raise RuntimeError(f"Could not obtain semantic type for column "
                                 f"'{name}' in table '{self.name}'. Change "
@@ -338,8 +380,9 @@ class Table(ABC):

      def print_metadata(self) -> None:
          r"""Prints the :meth:`~metadata` of this table."""
-         num_rows = self._num_rows()
-         num_rows_repr = ' ({num_rows:,} rows)' if num_rows is not None else ''
+         num_rows_repr = ''
+         if self._num_rows is not None:
+             num_rows_repr = ' ({self._num_rows:,} rows)'

          if in_notebook():
              from IPython.display import Markdown, display
@@ -384,7 +427,11 @@ class Table(ABC):
              column.name for column in self.columns if is_candidate(column)
          ]

-         if primary_key := self._infer_primary_key(candidates):
+         if primary_key := infer_primary_key(
+                 table_name=self.name,
+                 df=self._sample_df,
+                 candidates=candidates,
+         ):
              self.primary_key = primary_key
              logs.append(f"primary key '{primary_key}'")

@@ -395,7 +442,10 @@ class Table(ABC):
              if column.stype == Stype.timestamp
              and column.name != self._end_time_column
          ]
-         if time_column := self._infer_time_column(candidates):
+         if time_column := infer_time_column(
+                 df=self._sample_df,
+                 candidates=candidates,
+         ):
              self.time_column = time_column
              logs.append(f"time column '{time_column}'")

@@ -448,30 +498,43 @@ class Table(ABC):

      # Abstract method #########################################################

-     @abstractmethod
-     def _has_source_column(self, name: str) -> bool:
-         pass
+     @cached_property
+     def _source_column_dict(self) -> Dict[str, SourceColumn]:
+         return {col.name: col for col in self._get_source_columns()}

      @abstractmethod
-     def _get_source_dtype(self, name: str) -> Dtype:
+     def _get_source_columns(self) -> List[SourceColumn]:
          pass

-     @abstractmethod
-     def _get_source_stype(self, name: str, dtype: Dtype) -> Stype:
-         pass
+     @cached_property
+     def _source_foreign_key_dict(self) -> Dict[str, SourceForeignKey]:
+         fkeys = self._get_source_foreign_keys()
+         # NOTE Drop all keys that link to different primary keys in the same
+         # table since we don't support composite keys yet:
+         table_pkeys: Dict[str, Set[str]] = defaultdict(set)
+         for fkey in fkeys:
+             table_pkeys[fkey.dst_table].add(fkey.primary_key)
+         return {
+             fkey.name: fkey
+             for fkey in fkeys if len(table_pkeys[fkey.dst_table]) == 1
+         }

      @abstractmethod
-     def _get_source_foreign_keys(self) -> List[Tuple[str, str, str]]:
+     def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
          pass

-     @abstractmethod
-     def _infer_primary_key(self, candidates: List[str]) -> Optional[str]:
-         pass
+     @cached_property
+     def _sample_df(self) -> pd.DataFrame:
+         return self._get_sample_df()

      @abstractmethod
-     def _infer_time_column(self, candidates: List[str]) -> Optional[str]:
+     def _get_sample_df(self) -> pd.DataFrame:
          pass

-     @abstractmethod
+     @cached_property
      def _num_rows(self) -> Optional[int]:
+         return self._get_num_rows()
+
+     @abstractmethod
+     def _get_num_rows(self) -> Optional[int]:
          pass
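
The refactor above collapses the per-column hooks (`_has_source_column`, `_get_source_dtype`, `_get_source_stype`) and the backend-specific `_infer_*` methods into cached bulk accessors over a smaller abstract surface: `_get_source_columns`, `_get_source_foreign_keys`, `_get_sample_df`, and `_get_num_rows`. A minimal sketch of a conforming backend, assuming an in-memory pandas source; the `DataFrameTable` name and the `SourceColumn` keyword arguments are hypothetical, inferred only from the attribute accesses visible in the diff (`name`, `dtype`, `is_primary_key`, `is_unique_key`):

    from typing import List, Optional

    import pandas as pd

    from kumoai.experimental.rfm import Table
    from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey
    from kumoai.experimental.rfm.infer import infer_dtype


    class DataFrameTable(Table):  # Hypothetical backend, for illustration.
        def __init__(self, df: pd.DataFrame, name: str) -> None:
            self._df = df
            super().__init__(name=name)

        def _get_source_columns(self) -> List[SourceColumn]:
            # One entry per column; `Table` caches the result as
            # `_source_column_dict` and derives primary/unique keys from it.
            return [
                SourceColumn(name=col, dtype=infer_dtype(self._df[col]),
                             is_primary_key=False, is_unique_key=False)
                for col in self._df.columns
            ]

        def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
            return []  # A bare data frame carries no foreign key metadata.

        def _get_sample_df(self) -> pd.DataFrame:
            return self._df.head(10_000)  # Backs `_sample_df` for inference.

        def _get_num_rows(self) -> Optional[int]:
            return len(self._df)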
@@ -2,7 +2,9 @@ import contextlib
  import io
  import warnings
  from collections import defaultdict
+ from dataclasses import dataclass, field
  from importlib.util import find_spec
+ from pathlib import Path
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

  import pandas as pd
@@ -14,9 +16,18 @@ from typing_extensions import Self
  from kumoai import in_notebook
  from kumoai.experimental.rfm import Table
  from kumoai.graph import Edge
+ from kumoai.mixin import CastMixin

  if TYPE_CHECKING:
      import graphviz
+     from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
+     from snowflake.connector import SnowflakeConnection
+
+
+ @dataclass
+ class SqliteConnectionConfig(CastMixin):
+     uri: Union[str, Path]
+     kwargs: Dict[str, Any] = field(default_factory=dict)


  class Graph:
@@ -86,14 +97,17 @@ class Graph:
              self.add_table(table)

          for table in tables:
-             for fkey, dst_table, pkey in table._get_source_foreign_keys():
-                 if self[dst_table].primary_key is None:
-                     self[dst_table].primary_key = pkey
-                 elif self[dst_table]._primary_key != pkey:
+             for fkey in table._source_foreign_key_dict.values():
+                 if fkey.name not in table or fkey.dst_table not in self:
+                     continue
+                 if self[fkey.dst_table].primary_key is None:
+                     self[fkey.dst_table].primary_key = fkey.primary_key
+                 elif self[fkey.dst_table]._primary_key != fkey.primary_key:
                      raise ValueError(f"Found duplicate primary key definition "
-                                      f"'{self[dst_table]._primary_key}' and "
-                                      f"'{pkey}' in table '{dst_table}'.")
-                 self.link(table.name, fkey, dst_table)
+                                      f"'{self[fkey.dst_table]._primary_key}' "
+                                      f"and '{fkey.primary_key}' in table "
+                                      f"'{fkey.dst_table}'.")
+                 self.link(table.name, fkey.name, fkey.dst_table)

          for edge in (edges or []):
              _edge = Edge._cast(edge)
@@ -132,13 +146,6 @@ class Graph:
          ...     "table3": df3,
          ... })

-         >>> # Inspect table metadata:
-         >>> for table in graph.tables.values():
-         ...     table.print_metadata()
-
-         >>> # Visualize graph (if graphviz is installed):
-         >>> graph.visualize()
-
          Args:
              df_dict: A dictionary of data frames, where the keys are the names
                  of the tables and the values hold table data.
@@ -169,12 +176,17 @@ class Graph:
      @classmethod
      def from_sqlite(
          cls,
-         uri: Any,
+         connection: Union[
+             'AdbcSqliteConnection',
+             SqliteConnectionConfig,
+             str,
+             Path,
+             Dict[str, Any],
+         ],
          table_names: Optional[Sequence[str]] = None,
          edges: Optional[Sequence[Edge]] = None,
          infer_metadata: bool = True,
          verbose: bool = True,
-         conn_kwargs: Optional[Dict[str, Any]] = None,
      ) -> Self:
          r"""Creates a :class:`Graph` from a :class:`sqlite` database.

 
@@ -188,16 +200,10 @@ class Graph:
188
200
  >>> # Create a graph from a SQLite database:
189
201
  >>> graph = rfm.Graph.from_sqlite('data.db')
190
202
 
191
- >>> # Inspect table metadata:
192
- >>> for table in graph.tables.values():
193
- ... table.print_metadata()
194
-
195
- >>> # Visualize graph (if graphviz is installed):
196
- >>> graph.visualize()
197
-
198
203
  Args:
199
- uri: The path to the database file or an open connection obtained
200
- from :meth:`~kumoai.experimental.rfm.backend.sqlite.connect`.
204
+ connection: An open connection from
205
+ :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
206
+ path to the database file.
201
207
  table_names: Set of table names to include. If ``None``, will add
202
208
  all tables present in the database.
203
209
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
@@ -206,8 +212,6 @@ class Graph:
              infer_metadata: Whether to infer metadata for all tables in the
                  graph.
              verbose: Whether to print verbose output.
-             conn_kwargs: Additional connection arguments, following the
-                 :class:`adbc_driver_sqlite` protocol.
          """
          from kumoai.experimental.rfm.backend.sqlite import (
              Connection,
@@ -215,10 +219,11 @@ class Graph:
              connect,
          )

-         if not isinstance(uri, Connection):
-             connection = connect(uri, **(conn_kwargs or {}))
-         else:
-             connection = uri
+         if not isinstance(connection, Connection):
+             connection = SqliteConnectionConfig._cast(connection)
+             assert isinstance(connection, SqliteConnectionConfig)
+             connection = connect(connection.uri, **connection.kwargs)
+         assert isinstance(connection, Connection)

          if table_names is None:
              with connection.cursor() as cursor:
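
With `conn_kwargs` folded into the `connection` parameter, `from_sqlite` now accepts an open ADBC connection, a path, or a dict that is cast to `SqliteConnectionConfig`. A short usage sketch under that reading; `data.db` is a placeholder path:

    import kumoai.experimental.rfm as rfm
    from kumoai.experimental.rfm.backend.sqlite import connect

    # Path form (str or pathlib.Path):
    graph = rfm.Graph.from_sqlite('data.db')

    # Dict form, cast to SqliteConnectionConfig; 'kwargs' is forwarded to
    # the backend's connect() call:
    graph = rfm.Graph.from_sqlite({'uri': 'data.db', 'kwargs': {}})

    # Pre-opened connection, used as-is:
    graph = rfm.Graph.from_sqlite(connect('data.db'))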
@@ -242,6 +247,140 @@ class Graph:

          return graph

+     @classmethod
+     def from_snowflake(
+         cls,
+         connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
+         table_names: Optional[Sequence[str]] = None,
+         edges: Optional[Sequence[Edge]] = None,
+         infer_metadata: bool = True,
+         verbose: bool = True,
+     ) -> Self:
+         r"""Creates a :class:`Graph` from a :class:`snowflake` database and
+         schema.
+
+         Automatically infers table metadata and links by default.
+
+         .. code-block:: python
+
+             >>> # doctest: +SKIP
+             >>> import kumoai.experimental.rfm as rfm
+
+             >>> # Create a graph directly in a Snowflake notebook:
+             >>> graph = rfm.Graph.from_snowflake()
+
+         Args:
+             connection: An open connection from
+                 :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
+                 :class:`snowflake` connector keyword arguments to open a new
+                 connection. If ``None``, will re-use an active session in case
+                 it exists, or create a new connection from credentials stored
+                 in environment variables.
+             table_names: Set of table names to include. If ``None``, will add
+                 all tables present in the database.
+             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
+                 add to the graph. If not provided, edges will be automatically
+                 inferred from the data in case ``infer_metadata=True``.
+             infer_metadata: Whether to infer metadata for all tables in the
+                 graph.
+             verbose: Whether to print verbose output.
+         """
+         from kumoai.experimental.rfm.backend.snow import (
+             Connection,
+             SnowTable,
+             connect,
+         )
+
+         if not isinstance(connection, Connection):
+             connection = connect(**(connection or {}))
+         assert isinstance(connection, Connection)
+
+         if table_names is None:
+             with connection.cursor() as cursor:
+                 cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
+                 database, schema = cursor.fetchone()
+                 query = f"""
+                     SELECT TABLE_NAME
+                     FROM {database}.INFORMATION_SCHEMA.TABLES
+                     WHERE TABLE_SCHEMA = '{schema}'
+                 """
+                 cursor.execute(query)
+                 table_names = [row[0] for row in cursor.fetchall()]
+
+         tables = [SnowTable(connection, name) for name in table_names]
+
+         graph = cls(tables, edges=edges or [])
+
+         if infer_metadata:
+             graph.infer_metadata(False)
+
+         if edges is None:
+             graph.infer_links(False)
+
+         if verbose:
+             graph.print_metadata()
+             graph.print_links()
+
+         return graph
+
+     @classmethod
+     def from_snowflake_semantic_view(
+         cls,
+         semantic_view_name: str,
+         connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
+         verbose: bool = True,
+     ) -> Self:
+         import yaml
+
+         from kumoai.experimental.rfm.backend.snow import (
+             Connection,
+             SnowTable,
+             connect,
+         )
+
+         if not isinstance(connection, Connection):
+             connection = connect(**(connection or {}))
+         assert isinstance(connection, Connection)
+
+         with connection.cursor() as cursor:
+             cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
+                            f"'{semantic_view_name}')")
+             view = yaml.safe_load(cursor.fetchone()[0])
+
+         graph = cls(tables=[])
+
+         for table_desc in view['tables']:
+             primary_key: Optional[str] = None
+             if ('primary_key' in table_desc  # NOTE No composite keys yet.
+                     and len(table_desc['primary_key']['columns']) == 1):
+                 primary_key = table_desc['primary_key']['columns'][0]
+
+             table = SnowTable(
+                 connection,
+                 name=table_desc['base_table']['table'],
+                 database=table_desc['base_table']['database'],
+                 schema=table_desc['base_table']['schema'],
+                 primary_key=primary_key,
+             )
+             graph.add_table(table)
+
+         # TODO Find a solution to register time columns!
+
+         for relations in view['relationships']:
+             if len(relations['relationship_columns']) != 1:
+                 continue  # NOTE No composite keys yet.
+             graph.link(
+                 src_table=relations['left_table'],
+                 fkey=relations['relationship_columns'][0]['left_column'],
+                 dst_table=relations['right_table'],
+             )
+
+         if verbose:
+             graph.print_metadata()
+             graph.print_links()
+
+         return graph
+
      # Tables ##############################################################

      def has_table(self, name: str) -> bool:
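
Both new constructors mirror the `from_sqlite` flow: normalize `connection`, enumerate tables, then run the shared metadata and link inference, while `from_snowflake_semantic_view` takes primary keys and relationships from the view definition instead of inferring them. A hedged usage sketch; the credentials and the view name are placeholders:

    import kumoai.experimental.rfm as rfm

    # Inside a Snowflake notebook, the active session is re-used:
    graph = rfm.Graph.from_snowflake()

    # Elsewhere, pass snowflake.connector keyword arguments:
    graph = rfm.Graph.from_snowflake(connection={
        'account': '<account>',
        'user': '<user>',
        'password': '<password>',
    })

    # Or derive tables and links from a semantic view:
    graph = rfm.Graph.from_snowflake_semantic_view('MY_SEMANTIC_VIEW')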
@@ -612,10 +751,9 @@ class Graph:
              score += 1.0

          # Cardinality ratio:
-         src_num_rows = src_table._num_rows()
-         dst_num_rows = dst_table._num_rows()
-         if (src_num_rows is not None and dst_num_rows is not None
-                 and src_num_rows > dst_num_rows):
+         if (src_table._num_rows is not None
+                 and dst_table._num_rows is not None
+                 and src_table._num_rows > dst_table._num_rows):
              score += 1.0

          if score < 5.0:
@@ -1,13 +1,17 @@
+ from .dtype import infer_dtype
+ from .pkey import infer_primary_key
+ from .time_col import infer_time_column
  from .id import contains_id
  from .timestamp import contains_timestamp
  from .categorical import contains_categorical
  from .multicategorical import contains_multicategorical
- from .stype import infer_stype

  __all__ = [
+     'infer_dtype',
+     'infer_primary_key',
+     'infer_time_column',
      'contains_id',
      'contains_timestamp',
      'contains_categorical',
      'contains_multicategorical',
-     'infer_stype',
  ]
@@ -0,0 +1,79 @@
+ from typing import Dict
+
+ import numpy as np
+ import pandas as pd
+ import pyarrow as pa
+ from kumoapi.typing import Dtype
+
+ PANDAS_TO_DTYPE: Dict[str, Dtype] = {
+     'bool': Dtype.bool,
+     'boolean': Dtype.bool,
+     'int8': Dtype.int,
+     'int16': Dtype.int,
+     'int32': Dtype.int,
+     'int64': Dtype.int,
+     'float16': Dtype.float,
+     'float32': Dtype.float,
+     'float64': Dtype.float,
+     'object': Dtype.string,
+     'string': Dtype.string,
+     'string[python]': Dtype.string,
+     'string[pyarrow]': Dtype.string,
+     'binary': Dtype.binary,
+ }
+
+
+ def infer_dtype(ser: pd.Series) -> Dtype:
+     """Extracts the :class:`Dtype` from a :class:`pandas.Series`.
+
+     Args:
+         ser: A :class:`pandas.Series` to analyze.
+
+     Returns:
+         The data type.
+     """
+     if pd.api.types.is_datetime64_any_dtype(ser.dtype):
+         return Dtype.date
+     if pd.api.types.is_timedelta64_dtype(ser.dtype):
+         return Dtype.timedelta
+     if isinstance(ser.dtype, pd.CategoricalDtype):
+         return Dtype.string
+
+     if (pd.api.types.is_object_dtype(ser.dtype)
+             and not isinstance(ser.dtype, pd.ArrowDtype)):
+         index = ser.iloc[:1000].first_valid_index()
+         if index is not None and pd.api.types.is_list_like(ser[index]):
+             pos = ser.index.get_loc(index)
+             assert isinstance(pos, int)
+             ser = ser.iloc[pos:pos + 1000].dropna()
+             arr = pa.array(ser.tolist())
+             ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))
+
+     if isinstance(ser.dtype, pd.ArrowDtype):
+         if pa.types.is_list(ser.dtype.pyarrow_dtype):
+             elem_dtype = ser.dtype.pyarrow_dtype.value_type
+             if pa.types.is_integer(elem_dtype):
+                 return Dtype.intlist
+             if pa.types.is_floating(elem_dtype):
+                 return Dtype.floatlist
+             if pa.types.is_decimal(elem_dtype):
+                 return Dtype.floatlist
+             if pa.types.is_string(elem_dtype):
+                 return Dtype.stringlist
+             if pa.types.is_null(elem_dtype):
+                 return Dtype.floatlist
+
+     if isinstance(ser.dtype, np.dtype):
+         dtype_str = str(ser.dtype).lower()
+     elif isinstance(ser.dtype, pd.api.extensions.ExtensionDtype):
+         dtype_str = ser.dtype.name.lower()
+         dtype_str = dtype_str.split('[')[0]  # Remove backend metadata
+     elif isinstance(ser.dtype, pa.DataType):
+         dtype_str = str(ser.dtype).lower()
+     else:
+         dtype_str = 'object'
+
+     if dtype_str not in PANDAS_TO_DTYPE:
+         raise ValueError(f"Unsupported data type '{ser.dtype}'")
+
+     return PANDAS_TO_DTYPE[dtype_str]
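
`infer_dtype` normalizes pandas, NumPy, and Arrow dtypes onto the package's `Dtype` enum, first materializing object columns that hold Python lists into Arrow lists. The expected results below follow directly from the mapping table and branches above:

    import pandas as pd

    from kumoai.experimental.rfm.infer import infer_dtype

    infer_dtype(pd.Series([1, 2, 3]))                       # Dtype.int
    infer_dtype(pd.Series(['a', 'b']))                      # Dtype.string
    infer_dtype(pd.Series(pd.to_datetime(['2024-01-01'])))  # Dtype.date
    infer_dtype(pd.Series([[1, 2], [3]]))                   # Dtype.intlist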
@@ -5,7 +5,7 @@ from typing import Optional
  import pandas as pd


- def detect_primary_key(
+ def infer_primary_key(
      table_name: str,
      df: pd.DataFrame,
      candidates: list[str],
@@ -14,7 +14,7 @@ def detect_primary_key(

      Args:
          table_name: The table name.
-         df: The pandas DataFrame to analyze
+         df: The pandas DataFrame to analyze.
          candidates: A list of potential candidates.

      Returns:
@@ -124,102 +124,3 @@ def detect_primary_key(
                    f"key for this table manually.")

      return None
-
-
- def detect_time_column(
-     df: pd.DataFrame,
-     candidates: list[str],
- ) -> Optional[str]:
-     r"""Auto-detect potential time column.
-
-     Args:
-         df: The pandas DataFrame to analyze
-         candidates: A list of potential candidates.
-
-     Returns:
-         The name of the detected time column, or ``None`` if not found.
-     """
-     candidates = [  # Exclude all candidates with `*last*` in column names:
-         col_name for col_name in candidates
-         if not re.search(r'(^|_)last(_|$)', col_name, re.IGNORECASE)
-     ]
-
-     if len(candidates) == 0:
-         return None
-
-     if len(candidates) == 1:
-         return candidates[0]
-
-     # If there exists a dedicated `create*` column, use it as time column:
-     create_candidates = [
-         candidate for candidate in candidates
-         if candidate.lower().startswith('create')
-     ]
-     if len(create_candidates) == 1:
-         return create_candidates[0]
-     if len(create_candidates) > 1:
-         candidates = create_candidates
-
-     # Find the most optimal time column. Usually, it is the one pointing to
-     # the oldest timestamps:
-     with warnings.catch_warnings():
-         warnings.filterwarnings('ignore', message='Could not infer format')
-         min_timestamp_dict = {
-             key: pd.to_datetime(df[key].iloc[:10_000], 'coerce')
-             for key in candidates
-         }
-         min_timestamp_dict = {
-             key: value.min().tz_localize(None)
-             for key, value in min_timestamp_dict.items()
-         }
-         min_timestamp_dict = {
-             key: value
-             for key, value in min_timestamp_dict.items() if not pd.isna(value)
-         }
-
-     if len(min_timestamp_dict) == 0:
-         return None
-
-     return min(min_timestamp_dict, key=min_timestamp_dict.get)  # type: ignore
-
-
- PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
- MULTISPACE = re.compile(r"\s+")
-
-
- def normalize_text(
-     ser: pd.Series,
-     max_words: Optional[int] = 50,
- ) -> pd.Series:
-     r"""Normalizes text into a list of lower-case words.
-
-     Args:
-         ser: The :class:`pandas.Series` to normalize.
-         max_words: The maximum number of words to return.
-             This will auto-shrink any large text column to avoid blowing up
-             context size.
-     """
-     if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
-         return ser
-
-     def normalize_fn(line: str) -> list[str]:
-         line = PUNCTUATION.sub(" ", line)
-         line = re.sub(r"<br\s*/?>", " ", line)  # Handle <br /> or <br>
-         line = MULTISPACE.sub(" ", line)
-         words = line.split()
-         if max_words is not None:
-             words = words[:max_words]
-         return words
-
-     ser = ser.fillna('').astype(str)
-
-     if max_words is not None:
-         # We estimate the number of words as 5 characters + 1 space in an
-         # English text on average. We need this pre-filter here, as word
-         # splitting on a giant text can be very expensive:
-         ser = ser.str[:6 * max_words]
-
-     ser = ser.str.lower()
-     ser = ser.map(normalize_fn)
-
-     return ser
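
The removed helpers are relocated rather than dropped: per the `__init__.py` hunk above, `detect_time_column` reappears as `infer_time_column` in `time_col.py`, alongside the renamed `infer_primary_key` and the new `infer_dtype` (where `normalize_text` lands is not shown in this diff). The resulting public import surface of the `infer` package:

    from kumoai.experimental.rfm.infer import (
        contains_categorical,
        contains_id,
        contains_multicategorical,
        contains_timestamp,
        infer_dtype,
        infer_primary_key,
        infer_time_column,
    )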