polars-sgt 0.1.0__tar.gz → 0.2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. polars_sgt-0.2.5/CHANGELOG.md +30 -0
  2. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/Cargo.lock +2 -1
  3. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/Cargo.toml +2 -1
  4. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/PKG-INFO +24 -3
  5. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/README.md +22 -2
  6. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/__init__.py +2 -0
  7. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/functions.py +212 -6
  8. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/pyproject.toml +1 -0
  9. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/expressions.rs +5 -1
  10. polars_sgt-0.2.5/src/sgt_transform.rs +393 -0
  11. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_sgt_transform.py +7 -7
  12. polars_sgt-0.2.5/tests/verify_sgt.py +103 -0
  13. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/uv.lock +14 -0
  14. polars_sgt-0.1.0/src/sgt_transform.rs +0 -304
  15. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/.github/workflows/CI.yml +0 -0
  16. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/.gitignore +0 -0
  17. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/.python-version +0 -0
  18. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/.readthedocs.yaml +0 -0
  19. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/CODE_OF_CONDUCT.md +0 -0
  20. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/LICENSE +0 -0
  21. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/Makefile +0 -0
  22. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/assets/.DS_Store +0 -0
  23. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/assets/polars-business.png +0 -0
  24. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/bump_version.py +0 -0
  25. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/API.rst +0 -0
  26. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/Makefile +0 -0
  27. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/conf.py +0 -0
  28. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/index.rst +0 -0
  29. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/installation.rst +0 -0
  30. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/docs/requirements-docs.txt +0 -0
  31. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/dprint.json +0 -0
  32. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/licenses/NUMPY_LICENSE.txt +0 -0
  33. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/licenses/PANDAS_LICENSE.txt +0 -0
  34. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/.mypy.ini +0 -0
  35. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/_internal.pyi +0 -0
  36. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/namespace.py +0 -0
  37. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/py.typed +0 -0
  38. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/ranges.py +0 -0
  39. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/typing.py +0 -0
  40. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/polars_sgt/utils.py +0 -0
  41. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/requirements.txt +0 -0
  42. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/rust-toolchain.toml +0 -0
  43. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/arg_previous_greater.rs +0 -0
  44. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/format_localized.rs +0 -0
  45. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/lib.rs +0 -0
  46. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/month_delta.rs +0 -0
  47. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/timezone.rs +0 -0
  48. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/src/to_julian.rs +0 -0
  49. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/__init__.py +0 -0
  50. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/ceil_test.py +0 -0
  51. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/julian_date_test.py +0 -0
  52. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_benchmark.py +0 -0
  53. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_date_range.py +0 -0
  54. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_format_localized.py +0 -0
  55. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_is_busday.py +0 -0
  56. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_month_delta.py +0 -0
  57. {polars_sgt-0.1.0 → polars_sgt-0.2.5}/tests/test_timezone.py +0 -0
@@ -0,0 +1,30 @@
1
+ # Changelog
2
+
3
+ All notable changes to this project will be documented in this file.
4
+
5
+ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6
+ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7
+
8
+ ## [0.2.5] - 2026-02-04
9
+
10
+ ### Added
11
+ - `use_tqdm` parameter to `sgt_transform_df` to control progress bar visibility.
12
+ - `keep_original_name` parameter to `sgt_transform_df` to optionally restore original sequence ID names.
13
+ - Support for multiple columns in `sequence_id_col` in `sgt_transform_df` (automatically concatenates and splits).
14
+
15
+ ### Fixed
16
+ - `sgt_transform_df` now correctly handles `group_cols=None` by processing the entire DataFrame.
17
+ - `sgt_transform_df` now correctly filters subsets dynamically based on unique values of `group_cols` instead of hardcoded columns.
18
+
19
+ ## [0.2.0] - 2026-02-02
20
+
21
+ ### Added
22
+ - Parallel processing support with `rayon` for SGT transform.
23
+ - Support for custom output struct field names via `sequence_id_name` and `state_name` parameters.
24
+
25
+ ### Changed
26
+ - **Major Performance Optimization**: Rewrote SGT transform to use O(n) group-based indexing instead of O(n*m) scanning. Throughput increased to ~1.4M+ records/second.
27
+ - **Struct Field Rename (BREAKING)**: Renamed `ngram_values` field in the output struct to `value` for consistency with current Polars version and parameter names.
28
+
29
+ ### Fixed
30
+ - Performance bottleneck on large datasets (10M+ records).
@@ -2010,7 +2010,7 @@ dependencies = [
2010
2010
 
2011
2011
  [[package]]
2012
2012
  name = "polars_sgt"
2013
- version = "0.1.0"
2013
+ version = "0.2.5"
2014
2014
  dependencies = [
2015
2015
  "chrono",
2016
2016
  "chrono-tz",
@@ -2019,6 +2019,7 @@ dependencies = [
2019
2019
  "polars-ops",
2020
2020
  "pyo3",
2021
2021
  "pyo3-polars",
2022
+ "rayon",
2022
2023
  "serde",
2023
2024
  ]
2024
2025
 
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "polars_sgt"
3
- version = "0.1.0"
3
+ version = "0.2.5"
4
4
  edition = "2021"
5
5
  authors = ["Zedd <lytran14789@gmail.com>", "Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>"]
6
6
  readme = "README.md"
@@ -19,4 +19,5 @@ chrono-tz = "0.10.4"
19
19
  polars = { version = "0.52.0", features = ["strings", "timezones"]}
20
20
  polars-ops = { version = "0.52.0", default-features = false }
21
21
  polars-arrow = { version = "0.52.0", default-features = false }
22
+ rayon = "1.10"
22
23
 
@@ -1,12 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: polars-sgt
3
- Version: 0.1.0
3
+ Version: 0.2.5
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python :: Implementation :: CPython
6
6
  Classifier: Programming Language :: Python :: Implementation :: PyPy
7
7
  Requires-Dist: maturin>=1.11.5
8
8
  Requires-Dist: polars>=1.36.1
9
9
  Requires-Dist: pytest>=8.4.2
10
+ Requires-Dist: tqdm>=4.66.0
10
11
  License-File: LICENSE
11
12
  Summary: Sequence Graph Transform (SGT) for Polars - Transform sequential data into weighted n-gram representations
12
13
  Author-email: Zedd <lytran14789@gmail.com>, Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>
@@ -91,10 +92,30 @@ result = df.select(
91
92
  features = result.select([
92
93
  pl.col("sgt_features").struct.field("sequence_id"),
93
94
  pl.col("sgt_features").struct.field("ngram_keys").alias("ngrams"),
94
- pl.col("sgt_features").struct.field("ngram_values").alias("weights"),
95
+ pl.col("sgt_features").struct.field("value").alias("weights"),
95
96
  ]).explode(["ngrams", "weights"])
96
97
 
97
98
  print(features)
99
+
100
+ #OR
101
+ result = df.select(
102
+ sgt.sgt_transform(
103
+ "session_id",
104
+ "event",
105
+ time_col="time",
106
+ deltatime="m", # minutes
107
+ kappa=3, # trigrams
108
+ time_penalty="inverse",
109
+ mode="l2",
110
+ alpha=0.5
111
+ ).alias("struct_type")
112
+ )
113
+ out = (
114
+ result
115
+ .unnest("struct_type")
116
+ .explode(["ngram_keys", "value"])
117
+ .filter(pl.col("ngram_keys").str.split("->").list.len() > 0)
118
+ )
98
119
  ```
99
120
 
100
121
  ### With DateTime Columns
@@ -180,7 +201,7 @@ result = (
180
201
  Returns a Struct with three fields:
181
202
  - `sequence_id`: Original sequence identifier
182
203
  - `ngram_keys`: List of n-gram strings (e.g., "login -> view -> purchase")
183
- - `ngram_values`: List of corresponding weights
204
+ - `value`: List of corresponding weights
184
205
 
185
206
  ## Additional DateTime Utilities
186
207
 
@@ -72,10 +72,30 @@ result = df.select(
72
72
  features = result.select([
73
73
  pl.col("sgt_features").struct.field("sequence_id"),
74
74
  pl.col("sgt_features").struct.field("ngram_keys").alias("ngrams"),
75
- pl.col("sgt_features").struct.field("ngram_values").alias("weights"),
75
+ pl.col("sgt_features").struct.field("value").alias("weights"),
76
76
  ]).explode(["ngrams", "weights"])
77
77
 
78
78
  print(features)
79
+
80
+ #OR
81
+ result = df.select(
82
+ sgt.sgt_transform(
83
+ "session_id",
84
+ "event",
85
+ time_col="time",
86
+ deltatime="m", # minutes
87
+ kappa=3, # trigrams
88
+ time_penalty="inverse",
89
+ mode="l2",
90
+ alpha=0.5
91
+ ).alias("struct_type")
92
+ )
93
+ out = (
94
+ result
95
+ .unnest("struct_type")
96
+ .explode(["ngram_keys", "value"])
97
+ .filter(pl.col("ngram_keys").str.split("->").list.len() > 0)
98
+ )
79
99
  ```
80
100
 
81
101
  ### With DateTime Columns
@@ -161,7 +181,7 @@ result = (
161
181
  Returns a Struct with three fields:
162
182
  - `sequence_id`: Original sequence identifier
163
183
  - `ngram_keys`: List of n-gram strings (e.g., "login -> view -> purchase")
164
- - `ngram_values`: List of corresponding weights
184
+ - `value`: List of corresponding weights
165
185
 
166
186
  ## Additional DateTime Utilities
167
187
 
@@ -11,6 +11,7 @@ from polars_sgt.functions import (
11
11
  month_delta,
12
12
  month_name,
13
13
  sgt_transform,
14
+ sgt_transform_df,
14
15
  to_julian_date,
15
16
  to_local_datetime,
16
17
  )
@@ -30,6 +31,7 @@ __all__ = [
30
31
  "month_delta",
31
32
  "month_name",
32
33
  "sgt_transform",
34
+ "sgt_transform_df",
33
35
  "to_julian_date",
34
36
  "to_local_datetime",
35
37
  ]
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import sys
4
4
  from datetime import date
5
5
  from pathlib import Path
6
- from typing import TYPE_CHECKING, Literal
6
+ from typing import TYPE_CHECKING, Literal, Union, Any, Iterable
7
7
 
8
8
  import polars as pl
9
9
  from polars.plugins import register_plugin_function
@@ -675,7 +675,7 @@ def arg_previous_greater(expr: IntoExprColumn) -> pl.Expr:
675
675
 
676
676
 
677
677
  def sgt_transform(
678
- sequence_id_col: IntoExprColumn,
678
+ sequence_id_col: Union[IntoExprColumn, Iterable[IntoExprColumn]],
679
679
  state_col: IntoExprColumn,
680
680
  time_col: IntoExprColumn | None = None,
681
681
  *,
@@ -696,7 +696,7 @@ def sgt_transform(
696
696
  Parameters
697
697
  ----------
698
698
  sequence_id_col
699
- Column name containing sequence identifiers (groups)
699
+ Column or list of name/pl.col/pl.series containing sequence identifiers (groups)
700
700
  state_col
701
701
  Column name containing state/event values
702
702
  time_col
@@ -740,7 +740,7 @@ def sgt_transform(
740
740
  Struct expression containing:
741
741
  - sequence_id: Original sequence identifier
742
742
  - ngram_keys: List of n-gram strings
743
- - ngram_values: List of corresponding weights
743
+ - value: List of corresponding weights
744
744
 
745
745
  Examples
746
746
  --------
@@ -821,7 +821,7 @@ def sgt_transform(
821
821
  >>> df_features = result.select([
822
822
  ... pl.col("sgt_result").struct.field("sequence_id"),
823
823
  ... pl.col("sgt_result").struct.field("ngram_keys").alias("ngrams"),
824
- ... pl.col("sgt_result").struct.field("ngram_values").alias("weights"),
824
+ ... pl.col("sgt_result").struct.field("value").alias("weights"),
825
825
  ... ]).explode(["ngrams", "weights"])
826
826
 
827
827
  Notes
@@ -833,7 +833,12 @@ def sgt_transform(
833
833
  - Missing values in time columns are treated as 0
834
834
 
835
835
  """
836
- sequence_id_col = parse_into_expr(sequence_id_col)
836
+ # check if col is iterable
837
+ if isinstance(sequence_id_col, Iterable) and not isinstance(sequence_id_col, str):
838
+ sequence_id_cols = [parse_into_expr(col) for col in sequence_id_col]
839
+ sequence_id_col = pl.concat_str(sequence_id_cols, separator="--")
840
+ else:
841
+ sequence_id_col = parse_into_expr(sequence_id_col)
837
842
  state_col = parse_into_expr(state_col)
838
843
 
839
844
  if time_col is not None:
@@ -855,5 +860,206 @@ def sgt_transform(
855
860
  "alpha": alpha,
856
861
  "beta": beta,
857
862
  "deltatime": deltatime,
863
+ "sequence_id_name": None,
864
+ "state_name": None,
858
865
  },
859
866
  )
867
+
868
+ from tqdm import tqdm
869
def clean_name(n: str):
    """
    Normalise a pivoted column label into a compact, lower-case name.

    Steps, applied in this order: strip any leading/trailing ``{`` / ``}``
    characters, drop double quotes, turn commas into ``--``, drop spaces,
    then lower-case the result.
    """
    replacements = (('"', ""), (",", "--"), (" ", ""))
    cleaned = n.strip("{}")
    for old, new in replacements:
        cleaned = cleaned.replace(old, new)
    return cleaned.lower()
872
def clean_column_name(df: pl.DataFrame) -> pl.DataFrame:
    """
    Return ``df`` with every column renamed through ``clean_name``.

    Because ``clean_name`` removes spaces and lower-cases, two distinct
    column names can map to the same cleaned name.
    NOTE(review): such a collision is left to ``DataFrame.rename`` to handle.
    """
    renames = {}
    for old_name in df.columns:
        renames[old_name] = clean_name(old_name)
    return df.rename(renames)
879
def sgt_transform_df(
    df: Union[pl.DataFrame, pl.LazyFrame],
    sequence_id_col: Union[IntoExprColumn, Iterable[IntoExprColumn]],
    state_col: IntoExprColumn,
    time_col: IntoExprColumn | None = None,
    group_cols: Union[IntoExprColumn, Iterable[IntoExprColumn]] | None = None,
    *,
    kappa: int = 1,
    length_sensitive: bool = False,
    mode: Literal["l1", "l2", "none"] = "l1",
    time_penalty: Literal["inverse", "exponential", "linear", "power", "none"] = "inverse",
    alpha: float = 1.0,
    beta: float = 2.0,
    deltatime: Literal["s", "m", "h", "d", "w", "month", "q", "y"] | None = None,
    group_name: str = "sgt_",
    use_tqdm: bool = True,
    keep_original_name: bool = True,
) -> Union[pl.DataFrame, dict[Any, pl.DataFrame]]:
    """
    Apply SGT transform to a DataFrame and pivot the result to a wide table
    (one row per sequence, one column per n-gram), optionally per group.

    Parameters
    ----------
    df
        Input DataFrame or LazyFrame. A LazyFrame is collected once up front,
        because ``pivot`` and ``partition_by`` are eager-only operations.
    sequence_id_col
        Column(s) identifying sequences. If an iterable of columns is given
        (anything iterable other than ``str``, ``pl.Expr`` or ``pl.Series``),
        ``sgt_transform`` concatenates them with ``"--"``; they can optionally
        be split back (see ``keep_original_name``).
    state_col
        Column containing states/events
    time_col
        Optional column containing timestamps
    group_cols
        Column(s) to group by before applying SGT.
        If None, applies SGT to the whole DataFrame (sequences are still
        distinguished by ``sequence_id_col``).
        If provided, the DataFrame is partitioned into one subset per unique
        combination of these columns, in order of first appearance. A null
        group value forms its own group.
    kappa
        SGT kappa parameter
    length_sensitive
        SGT length_sensitive parameter
    mode
        SGT mode parameter
    time_penalty
        SGT time_penalty parameter
    alpha
        SGT alpha parameter
    beta
        SGT beta parameter
    deltatime
        SGT deltatime parameter
    group_name
        Prefix for keys in the returned dictionary when group_cols is used.
        Each key is ``group_name`` followed by the group values joined with
        ``"-"``. The same key is stored in a ``kind`` column of that group's
        DataFrame.
    use_tqdm
        Whether to show a progress bar when iterating over groups.
    keep_original_name
        If True and sequence_id_col was multiple columns, split the
        concatenated ``sequence_id`` back into columns named after the
        inputs' output names, and drop ``sequence_id``. If True and
        sequence_id_col is a single column name, rename ``sequence_id`` back
        to that name.

    Returns
    -------
    Union[pl.DataFrame, dict[Any, pl.DataFrame]]
        If group_cols is None, returns a single DataFrame with SGT features.
        If group_cols is provided, returns a dictionary where keys map to group values
        and values are DataFrames with SGT features.

    Raises
    ------
    ValueError
        If ``sequence_id_col`` is an empty iterable.

    Notes
    -----
    - N-gram column names are normalised with ``clean_column_name``
      (spaces removed, lower-cased).
    - Splitting a multi-column id relies on the ``"--"`` separator, so id
      values that themselves contain ``"--"`` do not round-trip correctly.
    - Distinct groups whose values join to the same key string overwrite
      each other in the returned dictionary.
    """
    # pivot() and partition_by() only exist on eager DataFrames.
    if isinstance(df, pl.LazyFrame):
        df = df.collect()

    # Any iterable other than a single name / expression / Series means
    # "several id columns". Materialise it once so it can be reused for every
    # group; a generator would otherwise be exhausted after the first group.
    is_multi_seq = isinstance(sequence_id_col, Iterable) and not isinstance(
        sequence_id_col, (str, pl.Expr, pl.Series)
    )
    original_seq_cols: list[str] = []
    if is_multi_seq:
        sequence_id_col = list(sequence_id_col)
        if not sequence_id_col:
            raise ValueError("sequence_id_col must contain at least one column")
        # Resolve real output names; str(pl.col("a")) would yield 'col("a")'.
        original_seq_cols = df.select(sequence_id_col).columns

    def _sgt_wide(frame: pl.DataFrame, kind: str | None) -> pl.DataFrame:
        """Run sgt_transform on ``frame`` and pivot to one row per sequence."""
        result = frame.select(
            sgt_transform(
                sequence_id_col,
                state_col,
                time_col=time_col,
                deltatime=deltatime,
                kappa=kappa,
                length_sensitive=length_sensitive,
                mode=mode,
                time_penalty=time_penalty,
                alpha=alpha,
                beta=beta,
            ).alias("struct_type")
        )
        out = result.unnest("struct_type").explode(["ngram_keys", "value"])

        index = ["sequence_id"]
        if kind is not None:
            out = out.with_columns(pl.lit(kind).alias("kind"))
            index = ["kind", "sequence_id"]

        wide = clean_column_name(out.pivot(on="ngram_keys", index=index, values="value"))

        if keep_original_name:
            if is_multi_seq:
                # sgt_transform joins multi-column ids with "--".
                parts = pl.col("sequence_id").str.split_exact("--", len(original_seq_cols) - 1)
                wide = wide.with_columns(
                    [
                        parts.struct.field(f"field_{i}").alias(name)
                        for i, name in enumerate(original_seq_cols)
                    ]
                ).drop("sequence_id")
            elif isinstance(sequence_id_col, str) and sequence_id_col != "sequence_id":
                wide = wide.rename({"sequence_id": sequence_id_col})
        return wide

    if group_cols is None:
        return _sgt_wide(df, None)

    # Resolve group column names (a name, expression, selector or an iterable
    # of them), then split the frame in one pass instead of re-filtering the
    # whole frame per group. partition_by also keeps null-valued groups,
    # which an `==` filter would silently drop.
    group_names = df.select(group_cols).columns
    partitions = df.partition_by(group_names, as_dict=True, maintain_order=True)

    items = partitions.items()
    if use_tqdm:
        items = tqdm(
            items,
            total=len(partitions),
            desc=f"Calculate SGT for each sub df in {group_name}",
        )

    dfs: dict[Any, pl.DataFrame] = {}
    for group_values, subset in items:
        key = f"{group_name}{'-'.join(str(v) for v in group_values)}"
        dfs[key] = _sgt_wide(subset, key)
    return dfs
@@ -22,6 +22,7 @@ dependencies = [
22
22
  "maturin>=1.11.5",
23
23
  "polars>=1.36.1",
24
24
  "pytest>=8.4.2",
25
+ "tqdm>=4.66.0",
25
26
  ]
26
27
 
27
28
  [project.urls]
@@ -30,6 +30,8 @@ pub struct SgtTransformKwargs {
30
30
  alpha: f64,
31
31
  beta: f64,
32
32
  deltatime: Option<String>,
33
+ sequence_id_name: Option<String>,
34
+ state_name: Option<String>,
33
35
  }
34
36
 
35
37
  pub fn to_local_datetime_output(input_fields: &[Field]) -> PolarsResult<Field> {
@@ -122,7 +124,7 @@ fn sgt_transform_output(_input_fields: &[Field]) -> PolarsResult<Field> {
122
124
  let fields = vec![
123
125
  Field::new(PlSmallStr::from_str("sequence_id"), DataType::String),
124
126
  Field::new(PlSmallStr::from_str("ngram_keys"), DataType::List(Box::new(DataType::String))),
125
- Field::new(PlSmallStr::from_str("ngram_values"), DataType::List(Box::new(DataType::Float64))),
127
+ Field::new(PlSmallStr::from_str("value"), DataType::List(Box::new(DataType::Float64))),
126
128
  ];
127
129
  Ok(Field::new(
128
130
  PlSmallStr::from_str("sgt_result"),
@@ -141,5 +143,7 @@ fn sgt_transform(inputs: &[Series], kwargs: SgtTransformKwargs) -> PolarsResult<
141
143
  kwargs.alpha,
142
144
  kwargs.beta,
143
145
  kwargs.deltatime.as_deref(),
146
+ kwargs.sequence_id_name.as_deref(),
147
+ kwargs.state_name.as_deref(),
144
148
  )
145
149
  }