kumoai 2.13.0.dev202512011731__cp312-cp312-macosx_11_0_arm64.whl → 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
@@ -2,8 +2,10 @@ import contextlib
  import io
  import warnings
  from collections import defaultdict
+ from dataclasses import dataclass, field
  from importlib.util import find_spec
- from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 
  import pandas as pd
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -14,9 +16,18 @@ from typing_extensions import Self
  from kumoai import in_notebook
  from kumoai.experimental.rfm import Table
  from kumoai.graph import Edge
+ from kumoai.mixin import CastMixin
 
  if TYPE_CHECKING:
      import graphviz
+     from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
+     from snowflake.connector import SnowflakeConnection
+
+
+ @dataclass
+ class SqliteConnectionConfig(CastMixin):
+     uri: Union[str, Path]
+     kwargs: Dict[str, Any] = field(default_factory=dict)
 
 
  class Graph:
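The new `SqliteConnectionConfig` lets the `from_sqlite` constructor (added further below) cast several input shapes into one config. A minimal sketch of the accepted forms, assuming a local `data.db` file (hypothetical path):

    import kumoai.experimental.rfm as rfm

    # Path-like input, cast to SqliteConnectionConfig(uri='data.db'):
    graph = rfm.Graph.from_sqlite('data.db')
    # Dict input, forwarding extra keyword arguments to connect():
    graph = rfm.Graph.from_sqlite({'uri': 'data.db', 'kwargs': {}})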
@@ -85,10 +96,24 @@ class Graph:
          for table in tables:
              self.add_table(table)
 
+         for table in tables:
+             for fkey in table._source_foreign_key_dict.values():
+                 if fkey.name not in table or fkey.dst_table not in self:
+                     continue
+                 if self[fkey.dst_table].primary_key is None:
+                     self[fkey.dst_table].primary_key = fkey.primary_key
+                 elif self[fkey.dst_table]._primary_key != fkey.primary_key:
+                     raise ValueError(f"Found duplicate primary key definition "
+                                      f"'{self[fkey.dst_table]._primary_key}' "
+                                      f"and '{fkey.primary_key}' in table "
+                                      f"'{fkey.dst_table}'.")
+                 self.link(table.name, fkey.name, fkey.dst_table)
+
          for edge in (edges or []):
              _edge = Edge._cast(edge)
              assert _edge is not None
-             self.link(*_edge)
+             if _edge not in self._edges:
+                 self.link(*_edge)
 
      @classmethod
      def from_data(
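The constructor now auto-registers edges from each table's source-level foreign keys and back-fills the referenced table's primary key. A hedged sketch of the resulting behavior (table names, column names, and the `Edge` tuple shape are illustrative, not taken from the package):

    # Suppose 'orders.user_id' is declared as a source foreign key
    # referencing 'users'. Constructing the graph now links the tables
    # and sets the 'users' primary key if it was left unset:
    graph = Graph([users, orders])
    # Passing the same edge explicitly is also safe now, since edges
    # already present in `self._edges` are skipped:
    graph = Graph([users, orders], edges=[Edge('orders', 'user_id', 'users')])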
@@ -101,7 +126,7 @@ class Graph:
          r"""Creates a :class:`Graph` from a dictionary of
          :class:`pandas.DataFrame` objects.
 
-         Automatically infers table metadata and links.
+         Automatically infers table metadata and links by default.
 
          .. code-block:: python
@@ -121,50 +146,180 @@
              ...     "table3": df3,
              ... })
 
-             >>> # Inspect table metadata:
-             >>> for table in graph.tables.values():
-             ...     table.print_metadata()
-
-             >>> # Visualize graph (if graphviz is installed):
-             >>> graph.visualize()
-
          Args:
              df_dict: A dictionary of data frames, where the keys are the names
                  of the tables and the values hold table data.
+             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
+                 add to the graph. If not provided, edges will be automatically
+                 inferred from the data if ``infer_metadata=True``.
              infer_metadata: Whether to infer metadata for all tables in the
                  graph.
+             verbose: Whether to print verbose output.
+         """
+         from kumoai.experimental.rfm.backend.local import LocalTable
+         tables = [LocalTable(df, name) for name, df in df_dict.items()]
+
+         graph = cls(tables, edges=edges or [])
+
+         if infer_metadata:
+             graph.infer_metadata(False)
+
+         if edges is None:
+             graph.infer_links(False)
+
+         if verbose:
+             graph.print_metadata()
+             graph.print_links()
+
+         return graph
+
+     @classmethod
+     def from_sqlite(
+         cls,
+         connection: Union[
+             'AdbcSqliteConnection',
+             SqliteConnectionConfig,
+             str,
+             Path,
+             Dict[str, Any],
+         ],
+         table_names: Optional[Sequence[str]] = None,
+         edges: Optional[Sequence[Edge]] = None,
+         infer_metadata: bool = True,
+         verbose: bool = True,
+     ) -> Self:
+         r"""Creates a :class:`Graph` from a SQLite database.
+
+         Automatically infers table metadata and links by default.
+
+         .. code-block:: python
+
+             >>> # doctest: +SKIP
+             >>> import kumoai.experimental.rfm as rfm
+
+             >>> # Create a graph from a SQLite database:
+             >>> graph = rfm.Graph.from_sqlite('data.db')
+
+         Args:
+             connection: An open connection from
+                 :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
+                 path to the database file.
+             table_names: Set of table names to include. If ``None``, will add
+                 all tables present in the database.
              edges: An optional list of :class:`~kumoai.graph.Edge` objects to
                  add to the graph. If not provided, edges will be automatically
-                 inferred from the data.
+                 inferred from the data if ``infer_metadata=True``.
+             infer_metadata: Whether to infer metadata for all tables in the
+                 graph.
              verbose: Whether to print verbose output.
+         """
+         from kumoai.experimental.rfm.backend.sqlite import (
+             Connection,
+             SQLiteTable,
+             connect,
+         )
+
+         if not isinstance(connection, Connection):
+             connection = SqliteConnectionConfig._cast(connection)
+             assert isinstance(connection, SqliteConnectionConfig)
+             connection = connect(connection.uri, **connection.kwargs)
+         assert isinstance(connection, Connection)
+
+         if table_names is None:
+             with connection.cursor() as cursor:
+                 cursor.execute("SELECT name FROM sqlite_master "
+                                "WHERE type='table'")
+                 table_names = [row[0] for row in cursor.fetchall()]
+
+         tables = [SQLiteTable(connection, name) for name in table_names]
 
-         Note:
-             This method will automatically infer metadata and links for the
-             graph.
+         graph = cls(tables, edges=edges or [])
+
+         if infer_metadata:
+             graph.infer_metadata(False)
+
+         if edges is None:
+             graph.infer_links(False)
+
+         if verbose:
+             graph.print_metadata()
+             graph.print_links()
+
+         return graph
+
+     @classmethod
+     def from_snowflake(
+         cls,
+         connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
+         table_names: Optional[Sequence[str]] = None,
+         edges: Optional[Sequence[Edge]] = None,
+         infer_metadata: bool = True,
+         verbose: bool = True,
+     ) -> Self:
+         r"""Creates a :class:`Graph` from a Snowflake database and
+         schema.
+
+         Automatically infers table metadata and links by default.
+
+         .. code-block:: python
 
-         Example:
              >>> # doctest: +SKIP
             >>> import kumoai.experimental.rfm as rfm
-             >>> df1 = pd.DataFrame(...)
-             >>> df2 = pd.DataFrame(...)
-             >>> df3 = pd.DataFrame(...)
-             >>> graph = rfm.Graph.from_data(data={
-             ...     "table1": df1,
-             ...     "table2": df2,
-             ...     "table3": df3,
-             ... })
-             >>> graph.validate()
+
+             >>> # Create a graph directly in a Snowflake notebook:
+             >>> graph = rfm.Graph.from_snowflake()
+
+         Args:
+             connection: An open connection from
+                 :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
+                 Snowflake connector keyword arguments to open a new
+                 connection. If ``None``, will re-use an active session in case
+                 it exists, or create a new connection from credentials stored
+                 in environment variables.
+             table_names: Set of table names to include. If ``None``, will add
+                 all tables present in the database.
+             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
+                 add to the graph. If not provided, edges will be automatically
+                 inferred from the data if ``infer_metadata=True``.
+             infer_metadata: Whether to infer metadata for all tables in the
+                 graph.
+             verbose: Whether to print verbose output.
          """
-         from kumoai.experimental.rfm import LocalTable
-         tables = [LocalTable(df, name) for name, df in df_dict.items()]
+         from kumoai.experimental.rfm.backend.snow import (
+             Connection,
+             SnowTable,
+             connect,
+         )
+
+         if not isinstance(connection, Connection):
+             connection = connect(**(connection or {}))
+         assert isinstance(connection, Connection)
+
+         if table_names is None:
+             with connection.cursor() as cursor:
+                 cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
+                 database, schema = cursor.fetchone()
+                 query = f"""
+                     SELECT TABLE_NAME
+                     FROM {database}.INFORMATION_SCHEMA.TABLES
+                     WHERE TABLE_SCHEMA = '{schema}'
+                 """
+                 cursor.execute(query)
+                 table_names = [row[0] for row in cursor.fetchall()]
+
+         tables = [SnowTable(connection, name) for name in table_names]
 
          graph = cls(tables, edges=edges or [])
 
          if infer_metadata:
-             graph.infer_metadata(verbose)
+             graph.infer_metadata(False)
 
          if edges is None:
-             graph.infer_links(verbose)
+             graph.infer_links(False)
+
+         if verbose:
+             graph.print_metadata()
+             graph.print_links()
 
          return graph
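A usage sketch of the three constructors side by side; the file name, table names, and environment-based Snowflake credentials are hypothetical:

    import kumoai.experimental.rfm as rfm

    graph = rfm.Graph.from_data({'users': users_df, 'orders': orders_df})
    graph = rfm.Graph.from_sqlite('data.db', table_names=['users', 'orders'])
    graph = rfm.Graph.from_snowflake()  # re-uses an active session, if any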
 
@@ -439,17 +594,13 @@ class Graph:
          return self
 
      def infer_links(self, verbose: bool = True) -> Self:
-         r"""Infers links for the tables and adds them as edges to the graph.
+         r"""Infers missing links for the tables and adds them as edges to the
+         graph.
 
          Args:
              verbose: Whether to print verbose output.
-
-         Note:
-             This function expects graph edges to be undefined upfront.
          """
-         if len(self.edges) > 0:
-             warnings.warn("Cannot infer links if graph edges already exist")
-             return self
+         known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
 
          # A list of primary key candidates (+score) for every column:
          candidate_dict: dict[
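`infer_links` is now incremental, matching its new "missing links" docstring: instead of warning and returning early when edges exist, it skips known `(src_table, fkey)` pairs and scores only the remaining candidates. A hedged usage sketch (edge arguments illustrative):

    # One edge is declared upfront; infer_links() now keeps it and only
    # searches for links among the remaining foreign-key candidates:
    graph = Graph(tables, edges=[Edge('orders', 'user_id', 'users')])
    graph.infer_links()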
@@ -474,6 +625,9 @@ class Graph:
              src_table_name = src_table.name.lower()
 
              for src_key in src_table.columns:
+                 if (src_table.name, src_key.name) in known_edges:
+                     continue
+
                  if src_key == src_table.primary_key:
                      continue  # Cannot link to primary key.
@@ -539,10 +693,9 @@ class Graph:
                          score += 1.0
 
                      # Cardinality ratio:
-                     src_num_rows = src_table._num_rows()
-                     dst_num_rows = dst_table._num_rows()
-                     if (src_num_rows is not None and dst_num_rows is not None
-                             and src_num_rows > dst_num_rows):
+                     if (src_table._num_rows is not None
+                             and dst_table._num_rows is not None
+                             and src_table._num_rows > dst_table._num_rows):
                          score += 1.0
 
                      if score < 5.0:
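Both `_num_rows` reads are now attribute accesses rather than calls, suggesting the table backends expose the row count as a (possibly cached) property. A sketch of the assumed shape, not the package's actual backend code:

    from functools import cached_property
    from typing import Optional

    class BackendTable:  # hypothetical stand-in for SQLiteTable/SnowTable
        @cached_property
        def _num_rows(self) -> Optional[int]:
            # Row count, or None when it cannot be determined cheaply:
            return None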
@@ -1,9 +1,15 @@
+ from .dtype import infer_dtype
+ from .pkey import infer_primary_key
+ from .time_col import infer_time_column
  from .id import contains_id
  from .timestamp import contains_timestamp
  from .categorical import contains_categorical
  from .multicategorical import contains_multicategorical
 
  __all__ = [
+     'infer_dtype',
+     'infer_primary_key',
+     'infer_time_column',
      'contains_id',
      'contains_timestamp',
      'contains_categorical',
@@ -0,0 +1,90 @@
+ from typing import Any, Dict
+
+ import numpy as np
+ import pandas as pd
+ import pyarrow as pa
+ from kumoapi.typing import Dtype
+
+ PANDAS_TO_DTYPE: Dict[Any, Dtype] = {
+     np.dtype('bool'): Dtype.bool,
+     pd.BooleanDtype(): Dtype.bool,
+     pa.bool_(): Dtype.bool,
+     np.dtype('byte'): Dtype.int,
+     pd.UInt8Dtype(): Dtype.int,
+     np.dtype('int16'): Dtype.int,
+     pd.Int16Dtype(): Dtype.int,
+     np.dtype('int32'): Dtype.int,
+     pd.Int32Dtype(): Dtype.int,
+     np.dtype('int64'): Dtype.int,
+     pd.Int64Dtype(): Dtype.int,
+     np.dtype('float32'): Dtype.float,
+     pd.Float32Dtype(): Dtype.float,
+     np.dtype('float64'): Dtype.float,
+     pd.Float64Dtype(): Dtype.float,
+     np.dtype('object'): Dtype.string,
+     pd.StringDtype(storage='python'): Dtype.string,
+     pd.StringDtype(storage='pyarrow'): Dtype.string,
+     pa.string(): Dtype.string,
+     pa.binary(): Dtype.binary,
+     np.dtype('datetime64[ns]'): Dtype.date,
+     np.dtype('timedelta64[ns]'): Dtype.timedelta,
+     pa.list_(pa.float32()): Dtype.floatlist,
+     pa.list_(pa.int64()): Dtype.intlist,
+     pa.list_(pa.string()): Dtype.stringlist,
+ }
+
+
+ def infer_dtype(ser: pd.Series) -> Dtype:
+     """Extracts the :class:`Dtype` from a :class:`pandas.Series`.
+
+     Args:
+         ser: A :class:`pandas.Series` to analyze.
+
+     Returns:
+         The data type.
+     """
+     if pd.api.types.is_datetime64_any_dtype(ser.dtype):
+         return Dtype.date
+
+     if isinstance(ser.dtype, pd.CategoricalDtype):
+         return Dtype.string
+
+     if pd.api.types.is_object_dtype(ser.dtype):
+         index = ser.iloc[:1000].first_valid_index()
+         if index is not None and pd.api.types.is_list_like(ser[index]):
+             pos = ser.index.get_loc(index)
+             assert isinstance(pos, int)
+             ser = ser.iloc[pos:pos + 1000].dropna()
+
+             if not ser.map(pd.api.types.is_list_like).all():
+                 raise ValueError("Data contains a mix of list-like and "
+                                  "non-list-like values")
+
+             # Remove all empty Python lists without known data type:
+             ser = ser[ser.map(lambda x: not isinstance(x, list) or len(x) > 0)]
+
+             # Infer unique data types in this series:
+             dtypes = ser.apply(lambda x: PANDAS_TO_DTYPE.get(
+                 np.array(x).dtype, Dtype.string)).unique().tolist()
+
+             invalid_dtypes = set(dtypes) - {
+                 Dtype.string,
+                 Dtype.int,
+                 Dtype.float,
+             }
+             if len(invalid_dtypes) > 0:
+                 raise ValueError(f"Data contains unsupported list data types: "
+                                  f"{list(invalid_dtypes)}")
+
+             if Dtype.string in dtypes:
+                 return Dtype.stringlist
+
+             if dtypes == [Dtype.int]:
+                 return Dtype.intlist
+
+             return Dtype.floatlist
+
+     if ser.dtype not in PANDAS_TO_DTYPE:
+         raise ValueError(f"Unsupported data type '{ser.dtype}'")
+
+     return PANDAS_TO_DTYPE[ser.dtype]
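A quick sketch of how the new `infer_dtype` helper behaves on plain and list-valued series (values illustrative; the module path of this new file is not shown in the diff):

    import pandas as pd
    from kumoapi.typing import Dtype
    # from <new dtype module> import infer_dtype  (path not shown above)

    assert infer_dtype(pd.Series([1, 2, 3])) == Dtype.int
    assert infer_dtype(pd.Series(['a', 'b'])) == Dtype.string
    # Object-dtype series of lists are sampled and mapped to list dtypes:
    assert infer_dtype(pd.Series([[1.0, 2.0], [3.0]])) == Dtype.floatlist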
@@ -0,0 +1,126 @@
+ import re
+ import warnings
+ from typing import Optional
+
+ import pandas as pd
+
+
+ def infer_primary_key(
+     table_name: str,
+     df: pd.DataFrame,
+     candidates: list[str],
+ ) -> Optional[str]:
+     r"""Auto-detect potential primary key column.
+
+     Args:
+         table_name: The table name.
+         df: The pandas DataFrame to analyze.
+         candidates: A list of potential candidates.
+
+     Returns:
+         The name of the detected primary key, or ``None`` if not found.
+     """
+     # A list of (potentially modified) table names that are eligible to match
+     # with a primary key, i.e.:
+     # - UserInfo -> User
+     # - snakecase <-> camelcase
+     # - camelcase <-> snakecase
+     # - plural <-> singular (users -> user, eligibilities -> eligibility)
+     # - verb -> noun (qualifying -> qualify)
+     _table_names = {table_name}
+     if table_name.lower().endswith('_info'):
+         _table_names.add(table_name[:-5])
+     elif table_name.lower().endswith('info'):
+         _table_names.add(table_name[:-4])
+
+     table_names = set()
+     for _table_name in _table_names:
+         table_names.add(_table_name.lower())
+         snakecase = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', _table_name)
+         snakecase = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', snakecase)
+         table_names.add(snakecase.lower())
+         camelcase = _table_name.replace('_', '')
+         table_names.add(camelcase.lower())
+         if _table_name.lower().endswith('s'):
+             table_names.add(_table_name.lower()[:-1])
+             table_names.add(snakecase.lower()[:-1])
+             table_names.add(camelcase.lower()[:-1])
+         else:
+             table_names.add(_table_name.lower() + 's')
+             table_names.add(snakecase.lower() + 's')
+             table_names.add(camelcase.lower() + 's')
+         if _table_name.lower().endswith('ies'):
+             table_names.add(_table_name.lower()[:-3] + 'y')
+             table_names.add(snakecase.lower()[:-3] + 'y')
+             table_names.add(camelcase.lower()[:-3] + 'y')
+         elif _table_name.lower().endswith('y'):
+             table_names.add(_table_name.lower()[:-1] + 'ies')
+             table_names.add(snakecase.lower()[:-1] + 'ies')
+             table_names.add(camelcase.lower()[:-1] + 'ies')
+         if _table_name.lower().endswith('ing'):
+             table_names.add(_table_name.lower()[:-3])
+             table_names.add(snakecase.lower()[:-3])
+             table_names.add(camelcase.lower()[:-3])
+
+     scores: list[tuple[str, float]] = []
+     for col_name in candidates:
+         col_name_lower = col_name.lower()
+
+         score = 0
+
+         if col_name_lower == 'id':
+             score += 4
+
+         for table_name_lower in table_names:
+
+             if col_name_lower == table_name_lower:
+                 score += 4  # USER -> USER
+                 break
+
+             for suffix in ['id', 'hash', 'key', 'code', 'uuid']:
+                 if not col_name_lower.endswith(suffix):
+                     continue
+
+                 if col_name_lower == f'{table_name_lower}_{suffix}':
+                     score += 5  # USER -> USER_ID
+                     break
+
+                 if col_name_lower == f'{table_name_lower}{suffix}':
+                     score += 5  # User -> UserId
+                     break
+
+                 if col_name_lower.endswith(f'{table_name_lower}_{suffix}'):
+                     score += 2
+
+                 if col_name_lower.endswith(f'{table_name_lower}{suffix}'):
+                     score += 2
+
+         # `rel-bench` hard-coding :(
+         if table_name == 'studies' and col_name == 'nct_id':
+             score += 1
+
+         ser = df[col_name].iloc[:1_000_000]
+         score += 3 * (ser.nunique() / len(ser))
+
+         scores.append((col_name, score))
+
+     scores = [x for x in scores if x[-1] >= 4]
+     scores.sort(key=lambda x: x[-1], reverse=True)
+
+     if len(scores) == 0:
+         return None
+
+     if len(scores) == 1:
+         return scores[0][0]
+
+     # In case of multiple candidates, only return one if its score is unique:
+     if scores[0][1] != scores[1][1]:
+         return scores[0][0]
+
+     max_score = scores[0][1]
+     candidates = [col_name for col_name, score in scores if score == max_score]
+     warnings.warn(f"Found multiple potential primary keys in table "
+                   f"'{table_name}': {candidates}. Please specify the primary "
+                   f"key for this table manually.")
+
+     return None
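A hedged sketch of `infer_primary_key` on a toy table (names illustrative): `user_id` matches the singularized table name plus an `id` suffix and is fully unique, clearing the 4.0 score threshold, while `age` is filtered out:

    import pandas as pd
    # from <new pkey module> import infer_primary_key  (path not shown above)

    df = pd.DataFrame({'user_id': [1, 2, 3], 'age': [20, 30, 40]})
    assert infer_primary_key('users', df, ['user_id', 'age']) == 'user_id'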
@@ -0,0 +1,62 @@
+ import re
+ import warnings
+ from typing import Optional
+
+ import pandas as pd
+
+
+ def infer_time_column(
+     df: pd.DataFrame,
+     candidates: list[str],
+ ) -> Optional[str]:
+     r"""Auto-detect potential time column.
+
+     Args:
+         df: The pandas DataFrame to analyze.
+         candidates: A list of potential candidates.
+
+     Returns:
+         The name of the detected time column, or ``None`` if not found.
+     """
+     candidates = [  # Exclude all candidates with `*last*` in column names:
+         col_name for col_name in candidates
+         if not re.search(r'(^|_)last(_|$)', col_name, re.IGNORECASE)
+     ]
+
+     if len(candidates) == 0:
+         return None
+
+     if len(candidates) == 1:
+         return candidates[0]
+
+     # If there exists a dedicated `create*` column, use it as time column:
+     create_candidates = [
+         candidate for candidate in candidates
+         if candidate.lower().startswith('create')
+     ]
+     if len(create_candidates) == 1:
+         return create_candidates[0]
+     if len(create_candidates) > 1:
+         candidates = create_candidates
+
+     # Find the most optimal time column. Usually, it is the one pointing to
+     # the oldest timestamps:
+     with warnings.catch_warnings():
+         warnings.filterwarnings('ignore', message='Could not infer format')
+         min_timestamp_dict = {
+             key: pd.to_datetime(df[key].iloc[:10_000], errors='coerce')
+             for key in candidates
+         }
+         min_timestamp_dict = {
+             key: value.min().tz_localize(None)
+             for key, value in min_timestamp_dict.items()
+         }
+         min_timestamp_dict = {
+             key: value
+             for key, value in min_timestamp_dict.items() if not pd.isna(value)
+         }
+
+     if len(min_timestamp_dict) == 0:
+         return None
+
+     return min(min_timestamp_dict, key=min_timestamp_dict.get)  # type: ignore
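A sketch of the resulting selection on a toy frame (column names illustrative): `last_login` is excluded by the `(^|_)last(_|$)` filter, and the dedicated `create*` column wins over the remaining candidates:

    import pandas as pd
    # from <new time_col module> import infer_time_column  (path not shown)

    df = pd.DataFrame({
        'created_at': pd.to_datetime(['2020-01-01', '2020-06-01']),
        'updated_at': pd.to_datetime(['2021-01-01', '2021-06-01']),
        'last_login': pd.to_datetime(['2022-01-01', '2022-06-01']),
    })
    assert infer_time_column(df, list(df.columns)) == 'created_at'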
@@ -1,3 +1,4 @@
+ import re
  from typing import Dict, List, Optional, Tuple
 
  import numpy as np
@@ -7,7 +8,47 @@ from kumoapi.typing import Stype
 
  import kumoai.kumolib as kumolib
  from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
- from kumoai.experimental.rfm.utils import normalize_text
+
+ PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
+ MULTISPACE = re.compile(r"\s+")
+
+
+ def normalize_text(
+     ser: pd.Series,
+     max_words: Optional[int] = 50,
+ ) -> pd.Series:
+     r"""Normalizes text into a list of lower-case words.
+
+     Args:
+         ser: The :class:`pandas.Series` to normalize.
+         max_words: The maximum number of words to return.
+             This will auto-shrink any large text column to avoid blowing up
+             context size.
+     """
+     if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
+         return ser
+
+     def normalize_fn(line: str) -> list[str]:
+         line = PUNCTUATION.sub(" ", line)
+         line = re.sub(r"<br\s*/?>", " ", line)  # Handle <br /> or <br>
+         line = MULTISPACE.sub(" ", line)
+         words = line.split()
+         if max_words is not None:
+             words = words[:max_words]
+         return words
+
+     ser = ser.fillna('').astype(str)
+
+     if max_words is not None:
+         # We estimate the number of words as 5 characters + 1 space in an
+         # English text on average. We need this pre-filter here, as word
+         # splitting on a giant text can be very expensive:
+         ser = ser.str[:6 * max_words]
+
+     ser = ser.str.lower()
+     ser = ser.map(normalize_fn)
+
+     return ser
 
 
  class LocalGraphSampler:
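Since `normalize_text` now lives in this module rather than `kumoai.experimental.rfm.utils`, a quick sketch of its behavior (input string illustrative): punctuation and `<br>` tags become spaces, text is lower-cased, and each entry becomes a list of at most `max_words` words:

    import pandas as pd

    ser = pd.Series(["Hello, World! <br> It's great."])
    print(normalize_text(ser).iloc[0])
    # ['hello', 'world', 'it', 's', 'great']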