polars-sgt 0.1.0__tar.gz → 0.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- polars_sgt-0.2.5/CHANGELOG.md +30 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/Cargo.lock +2 -1
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/Cargo.toml +2 -1
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/PKG-INFO +24 -3
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/README.md +22 -2
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/__init__.py +2 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/functions.py +212 -6
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/pyproject.toml +1 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/expressions.rs +5 -1
- polars_sgt-0.2.5/src/sgt_transform.rs +393 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_sgt_transform.py +7 -7
- polars_sgt-0.2.5/tests/verify_sgt.py +103 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/uv.lock +14 -0
- polars_sgt-0.1.0/src/sgt_transform.rs +0 -304
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/.github/workflows/CI.yml +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/.gitignore +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/.python-version +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/.readthedocs.yaml +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/CODE_OF_CONDUCT.md +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/LICENSE +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/Makefile +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/assets/.DS_Store +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/assets/polars-business.png +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/bump_version.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/API.rst +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/Makefile +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/conf.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/index.rst +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/installation.rst +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/requirements-docs.txt +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/dprint.json +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/licenses/NUMPY_LICENSE.txt +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/licenses/PANDAS_LICENSE.txt +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/.mypy.ini +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/_internal.pyi +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/namespace.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/py.typed +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/ranges.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/typing.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/utils.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/requirements.txt +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/rust-toolchain.toml +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/arg_previous_greater.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/format_localized.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/lib.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/month_delta.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/timezone.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/to_julian.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/__init__.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/ceil_test.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/julian_date_test.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_benchmark.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_date_range.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_format_localized.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_is_busday.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_month_delta.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_timezone.py +0 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
All notable changes to this project will be documented in this file.
|
|
4
|
+
|
|
5
|
+
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
|
6
|
+
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
|
7
|
+
|
|
8
|
+
## [0.2.5] - 2026-02-04
|
|
9
|
+
|
|
10
|
+
### Added
|
|
11
|
+
- `use_tqdm` parameter to `sgt_transform_df` to control progress bar visibility.
|
|
12
|
+
- `keep_original_name` parameter to `sgt_transform_df` to optionally restore original sequence ID names.
|
|
13
|
+
- Support for multiple columns in `sequence_id_col` in `sgt_transform_df` (automatically concatenates and splits).
|
|
14
|
+
|
|
15
|
+
### Fixed
|
|
16
|
+
- `sgt_transform_df` now correctly handles `group_cols=None` by processing the entire DataFrame.
|
|
17
|
+
- `sgt_transform_df` now correctly filters subsets dynamically based on unique values of `group_cols` instead of hardcoded columns.
|
|
18
|
+
|
|
19
|
+
## [0.2.0] - 2026-02-02
|
|
20
|
+
|
|
21
|
+
### Added
|
|
22
|
+
- Parallel processing support with `rayon` for SGT transform.
|
|
23
|
+
- Support for custom output struct field names via `sequence_id_name` and `state_name` parameters.
|
|
24
|
+
|
|
25
|
+
### Changed
|
|
26
|
+
- **Major Performance Optimization**: Rewrote SGT transform to use O(n) group-based indexing instead of O(n*m) scanning. Throughput increased to ~1.4M+ records/second.
|
|
27
|
+
- **Struct Field Rename (BREAKING)**: Renamed `ngram_values` field in the output struct to `value` for consistency with current Polars version and parameter names.
|
|
28
|
+
|
|
29
|
+
### Fixed
|
|
30
|
+
- Performance bottleneck on large datasets (10M+ records).
|
|
@@ -2010,7 +2010,7 @@ dependencies = [
|
|
|
2010
2010
|
|
|
2011
2011
|
[[package]]
|
|
2012
2012
|
name = "polars_sgt"
|
|
2013
|
-
version = "0.1.0"
|
|
2013
|
+
version = "0.2.5"
|
|
2014
2014
|
dependencies = [
|
|
2015
2015
|
"chrono",
|
|
2016
2016
|
"chrono-tz",
|
|
@@ -2019,6 +2019,7 @@ dependencies = [
|
|
|
2019
2019
|
"polars-ops",
|
|
2020
2020
|
"pyo3",
|
|
2021
2021
|
"pyo3-polars",
|
|
2022
|
+
"rayon",
|
|
2022
2023
|
"serde",
|
|
2023
2024
|
]
|
|
2024
2025
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[package]
|
|
2
2
|
name = "polars_sgt"
|
|
3
|
-
version = "0.1.0"
|
|
3
|
+
version = "0.2.5"
|
|
4
4
|
edition = "2021"
|
|
5
5
|
authors = ["Zedd <lytran14789@gmail.com>", "Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>"]
|
|
6
6
|
readme = "README.md"
|
|
@@ -19,4 +19,5 @@ chrono-tz = "0.10.4"
|
|
|
19
19
|
polars = { version = "0.52.0", features = ["strings", "timezones"]}
|
|
20
20
|
polars-ops = { version = "0.52.0", default-features = false }
|
|
21
21
|
polars-arrow = { version = "0.52.0", default-features = false }
|
|
22
|
+
rayon = "1.10"
|
|
22
23
|
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: polars-sgt
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.2.5
|
|
4
4
|
Classifier: Programming Language :: Rust
|
|
5
5
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
6
6
|
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
7
7
|
Requires-Dist: maturin>=1.11.5
|
|
8
8
|
Requires-Dist: polars>=1.36.1
|
|
9
9
|
Requires-Dist: pytest>=8.4.2
|
|
10
|
+
Requires-Dist: tqdm>=4.66.0
|
|
10
11
|
License-File: LICENSE
|
|
11
12
|
Summary: Sequence Graph Transform (SGT) for Polars - Transform sequential data into weighted n-gram representations
|
|
12
13
|
Author-email: Zedd <lytran14789@gmail.com>, Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>
|
|
@@ -91,10 +92,30 @@ result = df.select(
|
|
|
91
92
|
features = result.select([
|
|
92
93
|
pl.col("sgt_features").struct.field("sequence_id"),
|
|
93
94
|
pl.col("sgt_features").struct.field("ngram_keys").alias("ngrams"),
|
|
94
|
-
pl.col("sgt_features").struct.field("ngram_values").alias("weights"),
|
|
95
|
+
pl.col("sgt_features").struct.field("value").alias("weights"),
|
|
95
96
|
]).explode(["ngrams", "weights"])
|
|
96
97
|
|
|
97
98
|
print(features)
|
|
99
|
+
|
|
100
|
+
#OR
|
|
101
|
+
result = df.select(
|
|
102
|
+
sgt.sgt_transform(
|
|
103
|
+
"session_id",
|
|
104
|
+
"event",
|
|
105
|
+
time_col="time",
|
|
106
|
+
deltatime="m", # minutes
|
|
107
|
+
kappa=3, # trigrams
|
|
108
|
+
time_penalty="inverse",
|
|
109
|
+
mode="l2",
|
|
110
|
+
alpha=0.5
|
|
111
|
+
).alias("struct_type")
|
|
112
|
+
)
|
|
113
|
+
out = (
|
|
114
|
+
result
|
|
115
|
+
.unnest("struct_type")
|
|
116
|
+
.explode(["ngram_keys", "value"])
|
|
117
|
+
.filter(pl.col("ngram_keys").str.split("->").list.len() > 0)
|
|
118
|
+
)
|
|
98
119
|
```
|
|
99
120
|
|
|
100
121
|
### With DateTime Columns
|
|
@@ -180,7 +201,7 @@ result = (
|
|
|
180
201
|
Returns a Struct with three fields:
|
|
181
202
|
- `sequence_id`: Original sequence identifier
|
|
182
203
|
- `ngram_keys`: List of n-gram strings (e.g., "login -> view -> purchase")
|
|
183
|
-
- `ngram_values`: List of corresponding weights
|
|
204
|
+
- `value`: List of corresponding weights
|
|
184
205
|
|
|
185
206
|
## Additional DateTime Utilities
|
|
186
207
|
|
|
@@ -72,10 +72,30 @@ result = df.select(
|
|
|
72
72
|
features = result.select([
|
|
73
73
|
pl.col("sgt_features").struct.field("sequence_id"),
|
|
74
74
|
pl.col("sgt_features").struct.field("ngram_keys").alias("ngrams"),
|
|
75
|
-
pl.col("sgt_features").struct.field("ngram_values").alias("weights"),
|
|
75
|
+
pl.col("sgt_features").struct.field("value").alias("weights"),
|
|
76
76
|
]).explode(["ngrams", "weights"])
|
|
77
77
|
|
|
78
78
|
print(features)
|
|
79
|
+
|
|
80
|
+
#OR
|
|
81
|
+
result = df.select(
|
|
82
|
+
sgt.sgt_transform(
|
|
83
|
+
"session_id",
|
|
84
|
+
"event",
|
|
85
|
+
time_col="time",
|
|
86
|
+
deltatime="m", # minutes
|
|
87
|
+
kappa=3, # trigrams
|
|
88
|
+
time_penalty="inverse",
|
|
89
|
+
mode="l2",
|
|
90
|
+
alpha=0.5
|
|
91
|
+
).alias("struct_type")
|
|
92
|
+
)
|
|
93
|
+
out = (
|
|
94
|
+
result
|
|
95
|
+
.unnest("struct_type")
|
|
96
|
+
.explode(["ngram_keys", "value"])
|
|
97
|
+
.filter(pl.col("ngram_keys").str.split("->").list.len() > 0)
|
|
98
|
+
)
|
|
79
99
|
```
|
|
80
100
|
|
|
81
101
|
### With DateTime Columns
|
|
@@ -161,7 +181,7 @@ result = (
|
|
|
161
181
|
Returns a Struct with three fields:
|
|
162
182
|
- `sequence_id`: Original sequence identifier
|
|
163
183
|
- `ngram_keys`: List of n-gram strings (e.g., "login -> view -> purchase")
|
|
164
|
-
- `ngram_values`: List of corresponding weights
|
|
184
|
+
- `value`: List of corresponding weights
|
|
165
185
|
|
|
166
186
|
## Additional DateTime Utilities
|
|
167
187
|
|
|
@@ -11,6 +11,7 @@ from polars_sgt.functions import (
|
|
|
11
11
|
month_delta,
|
|
12
12
|
month_name,
|
|
13
13
|
sgt_transform,
|
|
14
|
+
sgt_transform_df,
|
|
14
15
|
to_julian_date,
|
|
15
16
|
to_local_datetime,
|
|
16
17
|
)
|
|
@@ -30,6 +31,7 @@ __all__ = [
|
|
|
30
31
|
"month_delta",
|
|
31
32
|
"month_name",
|
|
32
33
|
"sgt_transform",
|
|
34
|
+
"sgt_transform_df",
|
|
33
35
|
"to_julian_date",
|
|
34
36
|
"to_local_datetime",
|
|
35
37
|
]
|
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import sys
|
|
4
4
|
from datetime import date
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import TYPE_CHECKING, Literal
|
|
6
|
+
from typing import TYPE_CHECKING, Literal, Union, Any, Iterable
|
|
7
7
|
|
|
8
8
|
import polars as pl
|
|
9
9
|
from polars.plugins import register_plugin_function
|
|
@@ -675,7 +675,7 @@ def arg_previous_greater(expr: IntoExprColumn) -> pl.Expr:
|
|
|
675
675
|
|
|
676
676
|
|
|
677
677
|
def sgt_transform(
|
|
678
|
-
sequence_id_col: IntoExprColumn,
|
|
678
|
+
sequence_id_col: Union[IntoExprColumn, Iterable[IntoExprColumn]],
|
|
679
679
|
state_col: IntoExprColumn,
|
|
680
680
|
time_col: IntoExprColumn | None = None,
|
|
681
681
|
*,
|
|
@@ -696,7 +696,7 @@ def sgt_transform(
|
|
|
696
696
|
Parameters
|
|
697
697
|
----------
|
|
698
698
|
sequence_id_col
|
|
699
|
-
Column name containing sequence identifiers (groups)
|
|
699
|
+
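Column, or list of columns (each given as a name, a `pl.col` expression, or a `pl.Series`), containing sequence identifiers (groups). When several columns are given, they are concatenated into a single identifier using "--" as the separator.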
|
|
700
700
|
state_col
|
|
701
701
|
Column name containing state/event values
|
|
702
702
|
time_col
|
|
@@ -740,7 +740,7 @@ def sgt_transform(
|
|
|
740
740
|
Struct expression containing:
|
|
741
741
|
- sequence_id: Original sequence identifier
|
|
742
742
|
- ngram_keys: List of n-gram strings
|
|
743
|
-
- ngram_values: List of corresponding weights
|
|
743
|
+
- value: List of corresponding weights
|
|
744
744
|
|
|
745
745
|
Examples
|
|
746
746
|
--------
|
|
@@ -821,7 +821,7 @@ def sgt_transform(
|
|
|
821
821
|
>>> df_features = result.select([
|
|
822
822
|
... pl.col("sgt_result").struct.field("sequence_id"),
|
|
823
823
|
... pl.col("sgt_result").struct.field("ngram_keys").alias("ngrams"),
|
|
824
|
-
... pl.col("sgt_result").struct.field("ngram_values").alias("weights"),
|
|
824
|
+
... pl.col("sgt_result").struct.field("value").alias("weights"),
|
|
825
825
|
... ]).explode(["ngrams", "weights"])
|
|
826
826
|
|
|
827
827
|
Notes
|
|
@@ -833,7 +833,12 @@ def sgt_transform(
|
|
|
833
833
|
- Missing values in time columns are treated as 0
|
|
834
834
|
|
|
835
835
|
"""
|
|
836
|
-
|
|
836
|
+
# check if col is iterable
|
|
837
|
+
if isinstance(sequence_id_col, Iterable) and not isinstance(sequence_id_col, str):
|
|
838
|
+
sequence_id_cols = [parse_into_expr(col) for col in sequence_id_col]
|
|
839
|
+
sequence_id_col = pl.concat_str(sequence_id_cols, separator="--")
|
|
840
|
+
else:
|
|
841
|
+
sequence_id_col = parse_into_expr(sequence_id_col)
|
|
837
842
|
state_col = parse_into_expr(state_col)
|
|
838
843
|
|
|
839
844
|
if time_col is not None:
|
|
@@ -855,5 +860,206 @@ def sgt_transform(
|
|
|
855
860
|
"alpha": alpha,
|
|
856
861
|
"beta": beta,
|
|
857
862
|
"deltatime": deltatime,
|
|
863
|
+
"sequence_id_name": None,
|
|
864
|
+
"state_name": None,
|
|
858
865
|
},
|
|
859
866
|
)
|
|
867
|
+
|
|
868
|
+
from tqdm import tqdm
|
|
869
|
+
def clean_name(n:str):
|
|
870
|
+
new_n = n.strip("{}").replace('"', '').replace(',', '--').replace(" ", "").lower()
|
|
871
|
+
return new_n
|
|
872
|
+
def clean_column_name(df: pl.DataFrame)->pl.DataFrame:
|
|
873
|
+
cols = [clean_name(c) for c in df.columns]
|
|
874
|
+
return df.rename(
|
|
875
|
+
{
|
|
876
|
+
oc: c for oc, c in zip(df.columns, cols)
|
|
877
|
+
}
|
|
878
|
+
)
|
|
879
|
+
def sgt_transform_df(
|
|
880
|
+
df: Union[pl.DataFrame, pl.LazyFrame],
|
|
881
|
+
sequence_id_col: Union[IntoExprColumn, Iterable[IntoExprColumn]],
|
|
882
|
+
state_col: IntoExprColumn,
|
|
883
|
+
time_col: IntoExprColumn | None = None,
|
|
884
|
+
group_cols: Union[IntoExprColumn, Iterable[IntoExprColumn]] | None = None,
|
|
885
|
+
*,
|
|
886
|
+
kappa: int = 1,
|
|
887
|
+
length_sensitive: bool = False,
|
|
888
|
+
mode: Literal["l1", "l2", "none"] = "l1",
|
|
889
|
+
time_penalty: Literal["inverse", "exponential", "linear", "power", "none"] = "inverse",
|
|
890
|
+
alpha: float = 1.0,
|
|
891
|
+
beta: float = 2.0,
|
|
892
|
+
deltatime: Literal["s", "m", "h", "d", "w", "month", "q", "y"] | None = None,
|
|
893
|
+
group_name: str = "sgt_",
|
|
894
|
+
use_tqdm: bool = True,
|
|
895
|
+
keep_original_name: bool = True,
|
|
896
|
+
) -> Union[pl.DataFrame, dict[Any, pl.DataFrame]]:
|
|
897
|
+
"""
|
|
898
|
+
Apply SGT transform to a DataFrame, optionally grouped by columns.
|
|
899
|
+
|
|
900
|
+
Parameters
|
|
901
|
+
----------
|
|
902
|
+
df
|
|
903
|
+
Input DataFrame or LazyFrame
|
|
904
|
+
sequence_id_col
|
|
905
|
+
Column(s) identifying sequences. If multiple columns are provided,
|
|
906
|
+
they will be concatenated for processing and optionally restored.
|
|
907
|
+
state_col
|
|
908
|
+
Column containing states/events
|
|
909
|
+
time_col
|
|
910
|
+
Optional column containing timestamps
|
|
911
|
+
group_cols
|
|
912
|
+
Column(s) to group by before applying SGT.
|
|
913
|
+
If None, applies SGT to the whole DataFrame (or by existing sequence_id).
|
|
914
|
+
If provided, the DataFrame is split into subsets based on unique values of these columns.
|
|
915
|
+
kappa
|
|
916
|
+
SGT kappa parameter
|
|
917
|
+
length_sensitive
|
|
918
|
+
SGT length_sensitive parameter
|
|
919
|
+
mode
|
|
920
|
+
SGT mode parameter
|
|
921
|
+
time_penalty
|
|
922
|
+
SGT time_penalty parameter
|
|
923
|
+
alpha
|
|
924
|
+
SGT alpha parameter
|
|
925
|
+
beta
|
|
926
|
+
SGT beta parameter
|
|
927
|
+
deltatime
|
|
928
|
+
SGT deltatime parameter
|
|
929
|
+
group_name
|
|
930
|
+
Prefix for keys in the returned dictionary when group_cols is used.
|
|
931
|
+
use_tqdm
|
|
932
|
+
Whether to show a progress bar when iterating over groups.
|
|
933
|
+
keep_original_name
|
|
934
|
+
If True, and sequence_id_col was multiple columns, split the concatenated ID
|
|
935
|
+
back into original columns in the result.
|
|
936
|
+
|
|
937
|
+
Returns
|
|
938
|
+
-------
|
|
939
|
+
Union[pl.DataFrame, dict[Any, pl.DataFrame]]
|
|
940
|
+
If group_cols is None, returns a single DataFrame with SGT features.
|
|
941
|
+
If group_cols is provided, returns a dictionary whose keys are `group_name` followed by the group's values joined with "-",
|
|
942
|
+
and values are DataFrames with SGT features.
|
|
943
|
+
"""
|
|
944
|
+
|
|
945
|
+
# Handle multiple sequence ID columns
|
|
946
|
+
is_multi_seq = False
|
|
947
|
+
original_seq_cols = []
|
|
948
|
+
|
|
949
|
+
if isinstance(sequence_id_col, (list, tuple)) and not isinstance(sequence_id_col, str):
|
|
950
|
+
is_multi_seq = True
|
|
951
|
+
original_seq_cols = [str(c) for c in sequence_id_col]
|
|
952
|
+
|
|
953
|
+
# If no grouping is requested, just run sgt_transform on the whole DF
|
|
954
|
+
if group_cols is None:
|
|
955
|
+
result = df.select(
|
|
956
|
+
sgt_transform(
|
|
957
|
+
sequence_id_col,
|
|
958
|
+
state_col,
|
|
959
|
+
time_col=time_col,
|
|
960
|
+
deltatime=deltatime,
|
|
961
|
+
kappa=kappa,
|
|
962
|
+
length_sensitive=length_sensitive,
|
|
963
|
+
mode=mode,
|
|
964
|
+
time_penalty=time_penalty,
|
|
965
|
+
alpha=alpha,
|
|
966
|
+
beta=beta,
|
|
967
|
+
).alias("struct_type")
|
|
968
|
+
)
|
|
969
|
+
|
|
970
|
+
out = (
|
|
971
|
+
result
|
|
972
|
+
.unnest("struct_type")
|
|
973
|
+
.explode(["ngram_keys", "value"])
|
|
974
|
+
)
|
|
975
|
+
|
|
976
|
+
# Pivot to get features as columns
|
|
977
|
+
# Note: sequence_id in output is named "sequence_id" from the struct
|
|
978
|
+
df_sub = out.pivot(on="ngram_keys", index="sequence_id", values="value")
|
|
979
|
+
df_sub = clean_column_name(df_sub)
|
|
980
|
+
|
|
981
|
+
if keep_original_name:
|
|
982
|
+
# Identify the sequence id column in the result
|
|
983
|
+
# It comes out as "sequence_id" from the struct
|
|
984
|
+
|
|
985
|
+
if is_multi_seq:
|
|
986
|
+
# Split the "sequence_id" column back into original cols
|
|
987
|
+
# Assuming "--" separator as used in sgt_transform
|
|
988
|
+
split_exprs = [
|
|
989
|
+
pl.col("sequence_id").str.split_exact("--", len(original_seq_cols) - 1)
|
|
990
|
+
.struct.field(f"field_{i}")
|
|
991
|
+
.alias(col_name)
|
|
992
|
+
for i, col_name in enumerate(original_seq_cols)
|
|
993
|
+
]
|
|
994
|
+
df_sub = df_sub.with_columns(split_exprs).drop("sequence_id")
|
|
995
|
+
elif isinstance(sequence_id_col, str) and sequence_id_col != "sequence_id":
|
|
996
|
+
# Rename back to original name if it's a single string col
|
|
997
|
+
df_sub = df_sub.rename({"sequence_id": sequence_id_col})
|
|
998
|
+
|
|
999
|
+
return df_sub
|
|
1000
|
+
|
|
1001
|
+
# If grouping is requested
|
|
1002
|
+
if isinstance(group_cols, str):
|
|
1003
|
+
group_cols = [group_cols]
|
|
1004
|
+
|
|
1005
|
+
# Get unique combinations of group columns
|
|
1006
|
+
subset_filters = df.select(group_cols).unique().to_dicts()
|
|
1007
|
+
|
|
1008
|
+
dfs = {}
|
|
1009
|
+
|
|
1010
|
+
iterator = tqdm(subset_filters, desc=f"Calculate SGT for each sub df in {group_name}") if use_tqdm else subset_filters
|
|
1011
|
+
|
|
1012
|
+
for i in iterator:
|
|
1013
|
+
# Create filter expression
|
|
1014
|
+
filter_expr = pl.lit(True)
|
|
1015
|
+
key_parts = []
|
|
1016
|
+
for col_name, val in i.items():
|
|
1017
|
+
filter_expr &= (pl.col(col_name) == val)
|
|
1018
|
+
key_parts.append(str(val))
|
|
1019
|
+
|
|
1020
|
+
key = f"{group_name}{'-'.join(key_parts)}"
|
|
1021
|
+
|
|
1022
|
+
dfsub = df.filter(filter_expr)
|
|
1023
|
+
|
|
1024
|
+
result = dfsub.select(
|
|
1025
|
+
sgt_transform(
|
|
1026
|
+
sequence_id_col,
|
|
1027
|
+
state_col,
|
|
1028
|
+
time_col=time_col,
|
|
1029
|
+
deltatime=deltatime,
|
|
1030
|
+
kappa=kappa,
|
|
1031
|
+
length_sensitive=length_sensitive,
|
|
1032
|
+
mode=mode,
|
|
1033
|
+
time_penalty=time_penalty,
|
|
1034
|
+
alpha=alpha,
|
|
1035
|
+
beta=beta,
|
|
1036
|
+
).alias("struct_type")
|
|
1037
|
+
)
|
|
1038
|
+
|
|
1039
|
+
out = (
|
|
1040
|
+
result
|
|
1041
|
+
.unnest("struct_type")
|
|
1042
|
+
.explode(["ngram_keys", "value"])
|
|
1043
|
+
.with_columns(pl.lit(key).alias("kind"))
|
|
1044
|
+
)
|
|
1045
|
+
|
|
1046
|
+
# Pivot
|
|
1047
|
+
df_sub = out.pivot(on="ngram_keys", index=["kind", "sequence_id"], values="value")
|
|
1048
|
+
df_sub = clean_column_name(df_sub)
|
|
1049
|
+
|
|
1050
|
+
if keep_original_name:
|
|
1051
|
+
if is_multi_seq:
|
|
1052
|
+
# Split the "sequence_id" column back into original cols
|
|
1053
|
+
split_exprs = [
|
|
1054
|
+
pl.col("sequence_id").str.split_exact("--", len(original_seq_cols) - 1)
|
|
1055
|
+
.struct.field(f"field_{i}")
|
|
1056
|
+
.alias(col_name)
|
|
1057
|
+
for i, col_name in enumerate(original_seq_cols)
|
|
1058
|
+
]
|
|
1059
|
+
df_sub = df_sub.with_columns(split_exprs).drop("sequence_id")
|
|
1060
|
+
elif isinstance(sequence_id_col, str) and sequence_id_col != "sequence_id":
|
|
1061
|
+
df_sub = df_sub.rename({"sequence_id": sequence_id_col})
|
|
1062
|
+
|
|
1063
|
+
dfs[key] = df_sub
|
|
1064
|
+
|
|
1065
|
+
return dfs
|
|
@@ -30,6 +30,8 @@ pub struct SgtTransformKwargs {
|
|
|
30
30
|
alpha: f64,
|
|
31
31
|
beta: f64,
|
|
32
32
|
deltatime: Option<String>,
|
|
33
|
+
sequence_id_name: Option<String>,
|
|
34
|
+
state_name: Option<String>,
|
|
33
35
|
}
|
|
34
36
|
|
|
35
37
|
pub fn to_local_datetime_output(input_fields: &[Field]) -> PolarsResult<Field> {
|
|
@@ -122,7 +124,7 @@ fn sgt_transform_output(_input_fields: &[Field]) -> PolarsResult<Field> {
|
|
|
122
124
|
let fields = vec![
|
|
123
125
|
Field::new(PlSmallStr::from_str("sequence_id"), DataType::String),
|
|
124
126
|
Field::new(PlSmallStr::from_str("ngram_keys"), DataType::List(Box::new(DataType::String))),
|
|
125
|
-
Field::new(PlSmallStr::from_str("ngram_values"), DataType::List(Box::new(DataType::Float64))),
|
|
127
|
+
Field::new(PlSmallStr::from_str("value"), DataType::List(Box::new(DataType::Float64))),
|
|
126
128
|
];
|
|
127
129
|
Ok(Field::new(
|
|
128
130
|
PlSmallStr::from_str("sgt_result"),
|
|
@@ -141,5 +143,7 @@ fn sgt_transform(inputs: &[Series], kwargs: SgtTransformKwargs) -> PolarsResult<
|
|
|
141
143
|
kwargs.alpha,
|
|
142
144
|
kwargs.beta,
|
|
143
145
|
kwargs.deltatime.as_deref(),
|
|
146
|
+
kwargs.sequence_id_name.as_deref(),
|
|
147
|
+
kwargs.state_name.as_deref(),
|
|
144
148
|
)
|
|
145
149
|
}
|