kumoai 2.13.0.dev202512011731__cp312-cp312-macosx_11_0_arm64.whl → 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
@@ -7,7 +7,6 @@ from kumoapi.rfm.context import Subgraph
 from kumoapi.typing import Stype
 
 from kumoai.experimental.rfm import Graph, LocalTable
-from kumoai.experimental.rfm.utils import normalize_text
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
 try:
@@ -21,7 +20,6 @@ class LocalGraphStore:
     def __init__(
         self,
         graph: Graph,
-        preprocess: bool = False,
         verbose: Union[bool, ProgressLogger] = True,
     ) -> None:
 
@@ -32,7 +30,7 @@ class LocalGraphStore:
         )
 
         with verbose as logger:
-            self.df_dict, self.mask_dict = self.sanitize(graph, preprocess)
+            self.df_dict, self.mask_dict = self.sanitize(graph)
             self.stype_dict = self.get_stype_dict(graph)
             logger.log("Sanitized input data")
 
@@ -106,7 +104,6 @@ class LocalGraphStore:
     def sanitize(
         self,
         graph: Graph,
-        preprocess: bool = False,
    ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
        r"""Sanitizes raw data according to table schema definition:
 
@@ -115,10 +112,6 @@
         * drops timezone information from timestamps
         * drops duplicate primary keys
         * removes rows with missing primary keys or time values
-
-        If ``preprocess`` is set to ``True``, it will additionally pre-process
-        data for faster model processing. In particular, it:
-        * tokenizes any text column that is not a foreign key
         """
         df_dict: Dict[str, pd.DataFrame] = {}
         for table_name, table in graph.tables.items():
@@ -126,8 +119,6 @@
             df = table._data
             df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)
 
-        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
-
         mask_dict: Dict[str, np.ndarray] = {}
         for table in graph.tables.values():
             for col in table.columns:
@@ -145,12 +136,6 @@
                     ser = ser.dt.tz_localize(None)
                     df_dict[table.name][col.name] = ser
 
-                # Normalize text in advance (but exclude foreign keys):
-                if (preprocess and col.stype == Stype.text
-                        and (table.name, col.name) not in foreign_keys):
-                    ser = df_dict[table.name][col.name]
-                    df_dict[table.name][col.name] = normalize_text(ser)
-
            mask: Optional[np.ndarray] = None
            if table._time_column is not None:
                ser = df_dict[table.name][table._time_column]
@@ -150,26 +150,16 @@ class KumoRFM:
 
     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
     """
     def __init__(
         self,
         graph: Graph,
-        preprocess: bool = False,
         verbose: Union[bool, ProgressLogger] = True,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
+        self._graph_store = LocalGraphStore(graph, verbose)
         self._graph_sampler = LocalGraphSampler(self._graph_store)
 
         self._client: Optional[RFMAPI] = None
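
For downstream users, the practical effect of this hunk is the narrower KumoRFM signature. A minimal before/after sketch (assuming KumoRFM is re-exported from kumoai.experimental.rfm, with graph construction elided):

    from kumoai.experimental.rfm import Graph, KumoRFM

    graph = Graph(...)  # assumed: built as in the kumoai docs

    # Old wheel (2.13.0.dev202512011731): preprocessing was opt-in at construction:
    rfm = KumoRFM(graph, preprocess=True, verbose=True)

    # New wheel (2.13.0.dev202512031731): the keyword is gone; passing it now
    # raises a TypeError, so callers should simply drop it:
    rfm = KumoRFM(graph, verbose=True)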
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: kumoai
-Version: 2.13.0.dev202512011731
+Version: 2.13.0.dev202512031731
 Summary: AI on the Modern Data Stack
 Author-email: "Kumo.AI" <hello@kumo.ai>
 License-Expression: MIT
@@ -40,6 +40,8 @@ Requires-Dist: pytest-mock; extra == "test"
 Requires-Dist: requests-mock; extra == "test"
 Provides-Extra: sqlite
 Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
+Provides-Extra: snowflake
+Requires-Dist: snowflake-connector-python; extra == "snowflake"
 Provides-Extra: sagemaker
 Requires-Dist: boto3<2.0,>=1.30.0; extra == "sagemaker"
 Requires-Dist: mypy-boto3-sagemaker-runtime<2.0,>=1.34.0; extra == "sagemaker"
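
The new extra mirrors the existing sqlite one: installing kumoai[snowflake] additionally pulls in snowflake-connector-python, the driver that the new backend/snow modules listed in RECORD below presumably build on.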
@@ -1,6 +1,6 @@
 kumoai/_logging.py,sha256=U2_5ROdyk92P4xO4H2WJV8EC7dr6YxmmnM-b7QX9M7I,886
 kumoai/mixin.py,sha256=MP413xzuCqWhxAPUHmloLA3j4ZyF1tEtfi516b_hOXQ,812
-kumoai/_version.py,sha256=q8Zwyfwa6Ha3rL-zZFmN3WuqlLR8mLt6YklIO3ZtJg8,39
+kumoai/_version.py,sha256=5E8jDfy-Cd90GKsXB2iph05yeJqiO4NclFrisgQkb80,39
 kumoai/kumolib.cpython-312-darwin.so,sha256=xQvdWHx9xmQ11y3F3ywxJv6A0sDk6D3-2fQbxSdM1z4,232576
 kumoai/__init__.py,sha256=L3yOOtpSdwe3PYQlJBLkiQd3Ypp8iB5ChXkzprk3Si4,10546
 kumoai/formatting.py,sha256=jA_rLDCGKZI8WWCha-vtuLenVKTZvli99Tqpurz1H84,953
@@ -11,30 +11,35 @@ kumoai/databricks.py,sha256=e6E4lOFvZHXFwh4CO1kXU1zzDU3AapLQYMxjiHPC-HQ,476
 kumoai/spcs.py,sha256=N31d7rLa-bgYh8e2J4YzX1ScxGLqiVXrqJnCl1y4Mts,4139
 kumoai/_singleton.py,sha256=UTwrbDkoZSGB8ZelorvprPDDv9uZkUi1q_SrmsyngpQ,836
 kumoai/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-kumoai/experimental/rfm/local_graph_sampler.py,sha256=5DbhL9h0usFKSJfnx7HjLMPcG54qwJ48M2tmONqxXyY,6672
+kumoai/experimental/rfm/local_graph_sampler.py,sha256=QmyEw1M6CsftQpwwGMaUgPog7lt4EUfp5Y_KXIOK_oo,7887
 kumoai/experimental/rfm/local_pquery_driver.py,sha256=aO7Jfwx9gxGKYvpqxZx1LLWdI1MhuZQOPtAITxoOQO0,26162
-kumoai/experimental/rfm/graph.py,sha256=Ff_9-rOJRSDd4bnx73CkfQjAxU4fqzzQ3CIkZOYjR8c,30729
+kumoai/experimental/rfm/graph.py,sha256=5TZVbd4agFePPSazgqViAqWmLMpxHuDsX_DqHnqaNnM,36581
 kumoai/experimental/rfm/__init__.py,sha256=slliYcrh80xPtQQ_nnsp3ny9IbmHCyirmdZUfKTdME4,6064
-kumoai/experimental/rfm/utils.py,sha256=3IiBvT_aLBkkcJh3H11_50yt_XlEzHR0cm9Kprrtl8k,11123
 kumoai/experimental/rfm/sagemaker.py,sha256=_hTrFg4qfXe7uzwqSEG_wze-IFkwn7qde9OpUodCpbc,4982
-kumoai/experimental/rfm/rfm.py,sha256=0i6UA6Ds72midVq4ngaP4y5cnxs_GrwBF93Y0pFCX5k,48304
-kumoai/experimental/rfm/local_graph_store.py,sha256=2ehCC7bttSp8cKLe5_6GaMUdF4fbYaTR0rYPVQb6rEQ,13891
+kumoai/experimental/rfm/rfm.py,sha256=FZRrYK9uoH4IoGI1hQunORp1zrpfeyi8dDqikt6Gfpk,47703
+kumoai/experimental/rfm/local_graph_store.py,sha256=l6HMQBNdSdDEL0xIGhTcmR3E_JOIfZJPHDbiD0E7GlA,13140
 kumoai/experimental/rfm/authenticate.py,sha256=FiuHMvP7V3zBZUlHMDMbNLhc-UgDZgz4hjVSTuQ7DRw,18888
 kumoai/experimental/rfm/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=wQSH6esLny3EJID-fl__fxpiiG5J5Pj2EgIfbD0gx5I,604
-kumoai/experimental/rfm/backend/sqlite/table.py,sha256=WAUOEDDNNAMPlbN6W3ZsSfd8OLhrUT5yRa0ypde0iRs,4627
+kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=jYmZDNAVsojuPO1Q5idFmG5N0aCB8BDyrpAoS31n9bc,844
+kumoai/experimental/rfm/backend/sqlite/table.py,sha256=fnw6whxUmzjycFatlHqwVP64tujNY8RE20ZnAaZ9TJc,3417
 kumoai/experimental/rfm/backend/local/__init__.py,sha256=9rupbsPadaOqrEInv2nh9KEQ9mK8dSkbteMXwZmsGbU,896
-kumoai/experimental/rfm/backend/local/table.py,sha256=wxYmWCy_F3j2TQ5trN7Cn2hAkVDgIcSWMJa-uFkl8G0,5140
+kumoai/experimental/rfm/backend/local/table.py,sha256=wQU28OX6-vtdBvrHcoRt8XTBDScSahVIql-evINkS6Y,3014
+kumoai/experimental/rfm/backend/snow/__init__.py,sha256=B-tG-p8WA-mBuwvK1f0S2gdRPEGwApdxlnyeVSnY2xg,927
+kumoai/experimental/rfm/backend/snow/table.py,sha256=vUqUJUfphmZIm7h1a8X0IvX-jf_wT1Oh3YRuhqT_7M8,3460
 kumoai/experimental/rfm/pquery/__init__.py,sha256=X0O3EIq5SMfBEE-ii5Cq6iDhR3s3XMXB52Cx5htoePw,152
 kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=kiBJq7uVGbasG7TiqsubEl6ey3UYzZiM4bwxILqp_54,18487
 kumoai/experimental/rfm/pquery/executor.py,sha256=f7-pJhL0BgFU9E4o4gQpQyArOvyrZtwxFmks34-QOAE,2741
 kumoai/experimental/rfm/infer/multicategorical.py,sha256=0-cLpDnGryhr76QhZNO-klKokJ6MUSfxXcGdQ61oykY,1102
 kumoai/experimental/rfm/infer/categorical.py,sha256=VwNaKwKbRYkTxEJ1R6gziffC8dGsEThcDEfbi-KqW5c,853
+kumoai/experimental/rfm/infer/time_col.py,sha256=7R5Itl8RRBOr61qLpRTanIqrUVZFZcAXzDA9lCw4nx4,1820
+kumoai/experimental/rfm/infer/pkey.py,sha256=ubNqW1LIjLKiXbjXELAY3g6n2f3u2Eis_uC2DEiXFiU,4393
 kumoai/experimental/rfm/infer/id.py,sha256=ZIO0DWIoiEoS_8MVc5lkqBfkTWWQ0yGCgjkwLdaYa_Q,908
-kumoai/experimental/rfm/infer/__init__.py,sha256=xQ8_SuejIzXyn2J7bIKX3pXumFtRuEfBtE5oEDUDJjI,293
+kumoai/experimental/rfm/infer/dtype.py,sha256=IYhLyf4UoPZ-qqcUIt-enydRTnnNqY-sSim56V7uuUU,2979
+kumoai/experimental/rfm/infer/__init__.py,sha256=krdMFN8iKZlSFOl-M5MW1KuSviQV3H1E18jj2uB8g6Q,469
 kumoai/experimental/rfm/infer/timestamp.py,sha256=vM9--7eStzaGG13Y-oLYlpNJyhL6f9dp17HDXwtl_DM,1094
-kumoai/experimental/rfm/base/__init__.py,sha256=TlK19Ge4xTeuEMsZvSTA3uR2S0xnlYcJxXwWos-N4lw,94
-kumoai/experimental/rfm/base/table.py,sha256=YSjnsyn-5sX3IcCUgwed7JON7OaDH6n76PlX5GIgtVI,16946
+kumoai/experimental/rfm/base/__init__.py,sha256=-f3Ap-eUG1_JIX6NwRTZ2E3Rn0KTwt_PRYz8UcajWvg,189
+kumoai/experimental/rfm/base/table.py,sha256=CLC66JMBSJcvtvF8lecZywK-50_EzDHN6dc9ZekzpV0,19573
+kumoai/experimental/rfm/base/source.py,sha256=8_waFQVsctryHkm9BwmFZ9-vw5cXAXfjk7KDmcl_kic,272
 kumoai/experimental/rfm/base/column.py,sha256=izCJmufJcd1RSi-ptFMfrue-JYag38MJxizka7ya0-A,2319
 kumoai/encoder/__init__.py,sha256=VPGs4miBC_WfwWeOXeHhFomOUocERFavhKf5fqITcds,182
 kumoai/graph/graph.py,sha256=iyp4klPIMn2ttuEqMJvsrxKb_tmz_DTnvziIhCegduM,38291
@@ -99,8 +104,8 @@ kumoai/trainer/baseline_trainer.py,sha256=LlfViNOmswNv4c6zJJLsyv0pC2mM2WKMGYx06o
 kumoai/trainer/__init__.py,sha256=zUdFl-f-sBWmm2x8R-rdVzPBeU2FaMzUY5mkcgoTa1k,939
 kumoai/trainer/online_serving.py,sha256=9cddb5paeZaCgbUeceQdAOxysCtV5XP-KcsgFz_XR5w,9566
 kumoai/trainer/trainer.py,sha256=hBXO7gwpo3t59zKFTeIkK65B8QRmWCwO33sbDuEAPlY,20133
-kumoai-2.13.0.dev202512011731.dist-info/RECORD,,
-kumoai-2.13.0.dev202512011731.dist-info/WHEEL,sha256=V1loQ6TpxABu1APUg0MoTRBOzSKT5xVc3skizX-ovCU,136
-kumoai-2.13.0.dev202512011731.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
-kumoai-2.13.0.dev202512011731.dist-info/METADATA,sha256=q3ILcIOLhew8MHoXzkQ_z59UwQouGD6mk7U8jEpNGSQ,2376
-kumoai-2.13.0.dev202512011731.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102
+kumoai-2.13.0.dev202512031731.dist-info/RECORD,,
+kumoai-2.13.0.dev202512031731.dist-info/WHEEL,sha256=V1loQ6TpxABu1APUg0MoTRBOzSKT5xVc3skizX-ovCU,136
+kumoai-2.13.0.dev202512031731.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
+kumoai-2.13.0.dev202512031731.dist-info/METADATA,sha256=M7NK6i4Wz55zU6rHvkhts9ewT2dMeXtMIYnH5of8U-o,2466
+kumoai-2.13.0.dev202512031731.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102
@@ -1,344 +0,0 @@
-import re
-import warnings
-from typing import Any, Dict, Optional
-
-import numpy as np
-import pandas as pd
-import pyarrow as pa
-from kumoapi.typing import Dtype, Stype
-
-from kumoai.experimental.rfm.infer import (
-    contains_categorical,
-    contains_id,
-    contains_multicategorical,
-    contains_timestamp,
-)
-
-# Mapping from pandas/numpy dtypes to Kumo Dtypes
-PANDAS_TO_DTYPE: Dict[Any, Dtype] = {
-    np.dtype('bool'): Dtype.bool,
-    pd.BooleanDtype(): Dtype.bool,
-    pa.bool_(): Dtype.bool,
-    np.dtype('byte'): Dtype.int,
-    pd.UInt8Dtype(): Dtype.int,
-    np.dtype('int16'): Dtype.int,
-    pd.Int16Dtype(): Dtype.int,
-    np.dtype('int32'): Dtype.int,
-    pd.Int32Dtype(): Dtype.int,
-    np.dtype('int64'): Dtype.int,
-    pd.Int64Dtype(): Dtype.int,
-    np.dtype('float32'): Dtype.float,
-    pd.Float32Dtype(): Dtype.float,
-    np.dtype('float64'): Dtype.float,
-    pd.Float64Dtype(): Dtype.float,
-    np.dtype('object'): Dtype.string,
-    pd.StringDtype(storage='python'): Dtype.string,
-    pd.StringDtype(storage='pyarrow'): Dtype.string,
-    pa.string(): Dtype.string,
-    pa.binary(): Dtype.binary,
-    np.dtype('datetime64[ns]'): Dtype.date,
-    np.dtype('timedelta64[ns]'): Dtype.timedelta,
-    pa.list_(pa.float32()): Dtype.floatlist,
-    pa.list_(pa.int64()): Dtype.intlist,
-    pa.list_(pa.string()): Dtype.stringlist,
-}
-
-
-def to_dtype(ser: pd.Series) -> Dtype:
-    """Extracts the :class:`Dtype` from a :class:`pandas.Series`.
-
-    Args:
-        ser: A :class:`pandas.Series` to analyze.
-
-    Returns:
-        The data type.
-    """
-    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
-        return Dtype.date
-
-    if isinstance(ser.dtype, pd.CategoricalDtype):
-        return Dtype.string
-
-    if pd.api.types.is_object_dtype(ser.dtype):
-        index = ser.iloc[:1000].first_valid_index()
-        if index is not None and pd.api.types.is_list_like(ser[index]):
-            pos = ser.index.get_loc(index)
-            assert isinstance(pos, int)
-            ser = ser.iloc[pos:pos + 1000].dropna()
-
-            if not ser.map(pd.api.types.is_list_like).all():
-                raise ValueError("Data contains a mix of list-like and "
-                                 "non-list-like values")
-
-            ser = ser[ser.map(lambda x: not isinstance(x, list) or len(x) > 0)]
-
-            dtypes = ser.apply(lambda x: PANDAS_TO_DTYPE.get(
-                np.array(x).dtype, Dtype.string)).unique().tolist()
-
-            invalid_dtypes = set(dtypes) - {
-                Dtype.string,
-                Dtype.int,
-                Dtype.float,
-            }
-            if len(invalid_dtypes) > 0:
-                raise ValueError(f"Data contains unsupported list data types: "
-                                 f"{list(invalid_dtypes)}")
-
-            if Dtype.string in dtypes:
-                return Dtype.stringlist
-
-            if dtypes == [Dtype.int]:
-                return Dtype.intlist
-
-            return Dtype.floatlist
-
-    if ser.dtype not in PANDAS_TO_DTYPE:
-        raise ValueError(f"Unsupported data type '{ser.dtype}'")
-
-    return PANDAS_TO_DTYPE[ser.dtype]
-
-
-def infer_stype(ser: pd.Series, column_name: str, dtype: Dtype) -> Stype:
-    r"""Infers the semantic type of a column.
-
-    Args:
-        ser: A :class:`pandas.Series` to analyze.
-        column_name: The name of the column (used for pattern matching).
-        dtype: The data type.
-
-    Returns:
-        The semantic type.
-    """
-    if contains_id(ser, column_name, dtype):
-        return Stype.ID
-
-    if contains_timestamp(ser, column_name, dtype):
-        return Stype.timestamp
-
-    if contains_multicategorical(ser, column_name, dtype):
-        return Stype.multicategorical
-
-    if contains_categorical(ser, column_name, dtype):
-        return Stype.categorical
-
-    return dtype.default_stype
-
-
-def detect_primary_key(
-    table_name: str,
-    df: pd.DataFrame,
-    candidates: list[str],
-) -> Optional[str]:
-    r"""Auto-detect potential primary key column.
-
-    Args:
-        table_name: The table name.
-        df: The pandas DataFrame to analyze
-        candidates: A list of potential candidates.
-
-    Returns:
-        The name of the detected primary key, or ``None`` if not found.
-    """
-    # A list of (potentially modified) table names that are eligible to match
-    # with a primary key, i.e.:
-    # - UserInfo -> User
-    # - snakecase <-> camelcase
-    # - camelcase <-> snakecase
-    # - plural <-> singular (users -> user, eligibilities -> eligibility)
-    # - verb -> noun (qualifying -> qualify)
-    _table_names = {table_name}
-    if table_name.lower().endswith('_info'):
-        _table_names.add(table_name[:-5])
-    elif table_name.lower().endswith('info'):
-        _table_names.add(table_name[:-4])
-
-    table_names = set()
-    for _table_name in _table_names:
-        table_names.add(_table_name.lower())
-        snakecase = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', _table_name)
-        snakecase = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', snakecase)
-        table_names.add(snakecase.lower())
-        camelcase = _table_name.replace('_', '')
-        table_names.add(camelcase.lower())
-        if _table_name.lower().endswith('s'):
-            table_names.add(_table_name.lower()[:-1])
-            table_names.add(snakecase.lower()[:-1])
-            table_names.add(camelcase.lower()[:-1])
-        else:
-            table_names.add(_table_name.lower() + 's')
-            table_names.add(snakecase.lower() + 's')
-            table_names.add(camelcase.lower() + 's')
-        if _table_name.lower().endswith('ies'):
-            table_names.add(_table_name.lower()[:-3] + 'y')
-            table_names.add(snakecase.lower()[:-3] + 'y')
-            table_names.add(camelcase.lower()[:-3] + 'y')
-        elif _table_name.lower().endswith('y'):
-            table_names.add(_table_name.lower()[:-1] + 'ies')
-            table_names.add(snakecase.lower()[:-1] + 'ies')
-            table_names.add(camelcase.lower()[:-1] + 'ies')
-        if _table_name.lower().endswith('ing'):
-            table_names.add(_table_name.lower()[:-3])
-            table_names.add(snakecase.lower()[:-3])
-            table_names.add(camelcase.lower()[:-3])
-
-    scores: list[tuple[str, int]] = []
-    for col_name in candidates:
-        col_name_lower = col_name.lower()
-
-        score = 0
-
-        if col_name_lower == 'id':
-            score += 4
-
-        for table_name_lower in table_names:
-
-            if col_name_lower == table_name_lower:
-                score += 4  # USER -> USER
-                break
-
-            for suffix in ['id', 'hash', 'key', 'code', 'uuid']:
-                if not col_name_lower.endswith(suffix):
-                    continue
-
-                if col_name_lower == f'{table_name_lower}_{suffix}':
-                    score += 5  # USER -> USER_ID
-                    break
-
-                if col_name_lower == f'{table_name_lower}{suffix}':
-                    score += 5  # User -> UserId
-                    break
-
-                if col_name_lower.endswith(f'{table_name_lower}_{suffix}'):
-                    score += 2
-
-                if col_name_lower.endswith(f'{table_name_lower}{suffix}'):
-                    score += 2
-
-        # `rel-bench` hard-coding :(
-        if table_name == 'studies' and col_name == 'nct_id':
-            score += 1
-
-        ser = df[col_name].iloc[:1_000_000]
-        score += 3 * (ser.nunique() / len(ser))
-
-        scores.append((col_name, score))
-
-    scores = [x for x in scores if x[-1] >= 4]
-    scores.sort(key=lambda x: x[-1], reverse=True)
-
-    if len(scores) == 0:
-        return None
-
-    if len(scores) == 1:
-        return scores[0][0]
-
-    # In case of multiple candidates, only return one if its score is unique:
-    if scores[0][1] != scores[1][1]:
-        return scores[0][0]
-
-    max_score = max(scores, key=lambda x: x[1])
-    candidates = [col_name for col_name, score in scores if score == max_score]
-    warnings.warn(f"Found multiple potential primary keys in table "
-                  f"'{table_name}': {candidates}. Please specify the primary "
-                  f"key for this table manually.")
-
-    return None
-
-
-def detect_time_column(
-    df: pd.DataFrame,
-    candidates: list[str],
-) -> Optional[str]:
-    r"""Auto-detect potential time column.
-
-    Args:
-        df: The pandas DataFrame to analyze
-        candidates: A list of potential candidates.
-
-    Returns:
-        The name of the detected time column, or ``None`` if not found.
-    """
-    candidates = [  # Exclude all candidates with `*last*` in column names:
-        col_name for col_name in candidates
-        if not re.search(r'(^|_)last(_|$)', col_name, re.IGNORECASE)
-    ]
-
-    if len(candidates) == 0:
-        return None
-
-    if len(candidates) == 1:
-        return candidates[0]
-
-    # If there exists a dedicated `create*` column, use it as time column:
-    create_candidates = [
-        candidate for candidate in candidates
-        if candidate.lower().startswith('create')
-    ]
-    if len(create_candidates) == 1:
-        return create_candidates[0]
-    if len(create_candidates) > 1:
-        candidates = create_candidates
-
-    # Find the most optimal time column. Usually, it is the one pointing to
-    # the oldest timestamps:
-    with warnings.catch_warnings():
-        warnings.filterwarnings('ignore', message='Could not infer format')
-        min_timestamp_dict = {
-            key: pd.to_datetime(df[key].iloc[:10_000], 'coerce')
-            for key in candidates
-        }
-        min_timestamp_dict = {
-            key: value.min().tz_localize(None)
-            for key, value in min_timestamp_dict.items()
-        }
-        min_timestamp_dict = {
-            key: value
-            for key, value in min_timestamp_dict.items() if not pd.isna(value)
-        }
-
-    if len(min_timestamp_dict) == 0:
-        return None
-
-    return min(min_timestamp_dict, key=min_timestamp_dict.get)  # type: ignore
-
-
-PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
-MULTISPACE = re.compile(r"\s+")
-
-
-def normalize_text(
-    ser: pd.Series,
-    max_words: Optional[int] = 50,
-) -> pd.Series:
-    r"""Normalizes text into a list of lower-case words.
-
-    Args:
-        ser: The :class:`pandas.Series` to normalize.
-        max_words: The maximum number of words to return.
-            This will auto-shrink any large text column to avoid blowing up
-            context size.
-    """
-    if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
-        return ser
-
-    def normalize_fn(line: str) -> list[str]:
-        line = PUNCTUATION.sub(" ", line)
-        line = re.sub(r"<br\s*/?>", " ", line)  # Handle <br /> or <br>
-        line = MULTISPACE.sub(" ", line)
-        words = line.split()
-        if max_words is not None:
-            words = words[:max_words]
-        return words
-
-    ser = ser.fillna('').astype(str)
-
-    if max_words is not None:
-        # We estimate the number of words as 5 characters + 1 space in an
-        # English text on average. We need this pre-filter here, as word
-        # splitting on a giant text can be very expensive:
-        ser = ser.str[:6 * max_words]
-
-    ser = ser.str.lower()
-    ser = ser.map(normalize_fn)
-
-    return ser
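
For reference, a short sketch of what the deleted normalize_text helper did. It runs only against the old 2.13.0.dev202512011731 wheel, where the import removed in the first hunk above still resolves; the sample strings are invented:

    import pandas as pd

    from kumoai.experimental.rfm.utils import normalize_text  # old wheel only

    # Punctuation and <br> tags become spaces, text is lower-cased, and each
    # cell turns into a list of at most `max_words` (default 50) words:
    ser = pd.Series(["Hello, World! <br /> A GREAT movie...", None])
    print(normalize_text(ser).tolist())
    # [['hello', 'world', 'a', 'great', 'movie'], []]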