polars-sgt 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- polars_sgt-0.2.0/CHANGELOG.md +19 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/Cargo.lock +2 -1
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/Cargo.toml +2 -1
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/PKG-INFO +23 -3
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/README.md +22 -2
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/functions.py +4 -2
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/expressions.rs +5 -1
- polars_sgt-0.2.0/src/sgt_transform.rs +393 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_sgt_transform.py +7 -7
- polars_sgt-0.1.0/src/sgt_transform.rs +0 -304
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/.github/workflows/CI.yml +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/.gitignore +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/.python-version +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/.readthedocs.yaml +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/CODE_OF_CONDUCT.md +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/LICENSE +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/Makefile +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/assets/.DS_Store +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/assets/polars-business.png +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/bump_version.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/API.rst +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/Makefile +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/conf.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/index.rst +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/installation.rst +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/requirements-docs.txt +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/dprint.json +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/licenses/NUMPY_LICENSE.txt +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/licenses/PANDAS_LICENSE.txt +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/.mypy.ini +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/__init__.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/_internal.pyi +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/namespace.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/py.typed +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/ranges.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/typing.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/utils.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/pyproject.toml +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/requirements.txt +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/rust-toolchain.toml +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/arg_previous_greater.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/format_localized.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/lib.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/month_delta.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/timezone.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/to_julian.rs +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/__init__.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/ceil_test.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/julian_date_test.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_benchmark.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_date_range.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_format_localized.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_is_busday.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_month_delta.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_timezone.py +0 -0
- {polars_sgt-0.1.0 → polars_sgt-0.2.0}/uv.lock +0 -0
polars_sgt-0.2.0/CHANGELOG.md
@@ -0,0 +1,19 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [0.2.0] - 2026-02-02
+
+### Added
+- Parallel processing support with `rayon` for SGT transform.
+- Support for custom output struct field names via `sequence_id_name` and `state_name` parameters.
+
+### Changed
+- **Major Performance Optimization**: Rewrote SGT transform to use O(n) group-based indexing instead of O(n*m) scanning. Throughput increased to ~1.4M+ records/second.
+- **Struct Field Rename (BREAKING)**: Renamed the `ngram_values` field in the output struct to `value` for consistency with the current Polars version and parameter names.
+
+### Fixed
+- Performance bottleneck on large datasets (10M+ records).
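The `ngram_values` → `value` rename is breaking for any code that reads the weights field by name. A minimal migration sketch (the stand-in DataFrame below only mimics the plugin's output shape; real values come from `sgt.sgt_transform`, and the `sgt_features` column name follows the README examples in this diff):

```python
import polars as pl

# Stand-in for the plugin output: a struct column with the 0.2.0 field names.
result = pl.DataFrame({
    "sgt_features": [
        {"sequence_id": "s1", "ngram_keys": ["a", "a -> b"], "value": [1.0, 0.5]},
    ]
})

# 0.1.x read the weights via .struct.field("ngram_values");
# in 0.2.0 the field is named "value":
weights = result.select(
    pl.col("sgt_features").struct.field("value").alias("weights")
)
print(weights)
```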
{polars_sgt-0.1.0 → polars_sgt-0.2.0}/Cargo.lock
@@ -2010,7 +2010,7 @@ dependencies = [
 
 [[package]]
 name = "polars_sgt"
-version = "0.1.0"
+version = "0.2.0"
 dependencies = [
  "chrono",
  "chrono-tz",
@@ -2019,6 +2019,7 @@ dependencies = [
  "polars-ops",
  "pyo3",
  "pyo3-polars",
+ "rayon",
  "serde",
 ]
 
{polars_sgt-0.1.0 → polars_sgt-0.2.0}/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "polars_sgt"
-version = "0.1.0"
+version = "0.2.0"
 edition = "2021"
 authors = ["Zedd <lytran14789@gmail.com>", "Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>"]
 readme = "README.md"
@@ -19,4 +19,5 @@ chrono-tz = "0.10.4"
 polars = { version = "0.52.0", features = ["strings", "timezones"]}
 polars-ops = { version = "0.52.0", default-features = false }
 polars-arrow = { version = "0.52.0", default-features = false }
+rayon = "1.10"
 
{polars_sgt-0.1.0 → polars_sgt-0.2.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: polars-sgt
-Version: 0.1.0
+Version: 0.2.0
 Classifier: Programming Language :: Rust
 Classifier: Programming Language :: Python :: Implementation :: CPython
 Classifier: Programming Language :: Python :: Implementation :: PyPy
@@ -91,10 +91,30 @@ result = df.select(
 features = result.select([
     pl.col("sgt_features").struct.field("sequence_id"),
     pl.col("sgt_features").struct.field("ngram_keys").alias("ngrams"),
-    pl.col("sgt_features").struct.field("ngram_values").alias("weights"),
+    pl.col("sgt_features").struct.field("value").alias("weights"),
 ]).explode(["ngrams", "weights"])
 
 print(features)
+
+# OR
+result = df.select(
+    sgt.sgt_transform(
+        "session_id",
+        "event",
+        time_col="time",
+        deltatime="m",  # minutes
+        kappa=3,  # trigrams
+        time_penalty="inverse",
+        mode="l2",
+        alpha=0.5
+    ).alias("struct_type")
+)
+out = (
+    result
+    .unnest("struct_type")
+    .explode(["ngram_keys", "value"])
+    .filter(pl.col("ngram_keys").str.split("->").list.len() > 0)
+)
 ```
 
 ### With DateTime Columns
@@ -180,7 +200,7 @@ result = (
 Returns a Struct with three fields:
 - `sequence_id`: Original sequence identifier
 - `ngram_keys`: List of n-gram strings (e.g., "login -> view -> purchase")
-- `ngram_values`: List of corresponding weights
+- `value`: List of corresponding weights
 
 ## Additional DateTime Utilities
 
{polars_sgt-0.1.0 → polars_sgt-0.2.0}/README.md
@@ -72,10 +72,30 @@ result = df.select(
 features = result.select([
     pl.col("sgt_features").struct.field("sequence_id"),
     pl.col("sgt_features").struct.field("ngram_keys").alias("ngrams"),
-    pl.col("sgt_features").struct.field("ngram_values").alias("weights"),
+    pl.col("sgt_features").struct.field("value").alias("weights"),
 ]).explode(["ngrams", "weights"])
 
 print(features)
+
+# OR
+result = df.select(
+    sgt.sgt_transform(
+        "session_id",
+        "event",
+        time_col="time",
+        deltatime="m",  # minutes
+        kappa=3,  # trigrams
+        time_penalty="inverse",
+        mode="l2",
+        alpha=0.5
+    ).alias("struct_type")
+)
+out = (
+    result
+    .unnest("struct_type")
+    .explode(["ngram_keys", "value"])
+    .filter(pl.col("ngram_keys").str.split("->").list.len() > 0)
+)
 ```
 
 ### With DateTime Columns
@@ -161,7 +181,7 @@ result = (
 Returns a Struct with three fields:
 - `sequence_id`: Original sequence identifier
 - `ngram_keys`: List of n-gram strings (e.g., "login -> view -> purchase")
-- `ngram_values`: List of corresponding weights
+- `value`: List of corresponding weights
 
 ## Additional DateTime Utilities
 
{polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/functions.py
@@ -740,7 +740,7 @@ def sgt_transform(
     Struct expression containing:
     - sequence_id: Original sequence identifier
     - ngram_keys: List of n-gram strings
-    - ngram_values: List of corresponding weights
+    - value: List of corresponding weights
 
     Examples
     --------
@@ -821,7 +821,7 @@ def sgt_transform(
     >>> df_features = result.select([
     ...     pl.col("sgt_result").struct.field("sequence_id"),
     ...     pl.col("sgt_result").struct.field("ngram_keys").alias("ngrams"),
-    ...     pl.col("sgt_result").struct.field("ngram_values").alias("weights"),
+    ...     pl.col("sgt_result").struct.field("value").alias("weights"),
     ... ]).explode(["ngrams", "weights"])
 
     Notes
@@ -855,5 +855,7 @@ def sgt_transform(
         "alpha": alpha,
         "beta": beta,
         "deltatime": deltatime,
+        "sequence_id_name": None,
+        "state_name": None,
     },
 )
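Note that while the changelog lists `sequence_id_name` and `state_name` as parameters, this hunk pins both to `None` in the kwargs passed to the plugin, so the Rust side falls back to its defaults (`sequence_id` for the id field; `state` is reserved for future use). Nothing in this diff adds them to the `sgt_transform` signature itself.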
{polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/expressions.rs
@@ -30,6 +30,8 @@ pub struct SgtTransformKwargs {
     alpha: f64,
     beta: f64,
     deltatime: Option<String>,
+    sequence_id_name: Option<String>,
+    state_name: Option<String>,
 }
 
 pub fn to_local_datetime_output(input_fields: &[Field]) -> PolarsResult<Field> {
@@ -122,7 +124,7 @@ fn sgt_transform_output(_input_fields: &[Field]) -> PolarsResult<Field> {
     let fields = vec![
         Field::new(PlSmallStr::from_str("sequence_id"), DataType::String),
         Field::new(PlSmallStr::from_str("ngram_keys"), DataType::List(Box::new(DataType::String))),
-        Field::new(PlSmallStr::from_str("ngram_values"), DataType::List(Box::new(DataType::Float64))),
+        Field::new(PlSmallStr::from_str("value"), DataType::List(Box::new(DataType::Float64))),
     ];
     Ok(Field::new(
         PlSmallStr::from_str("sgt_result"),
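For reference, the fields declared in `sgt_transform_output` above give the `sgt_result` column the following Polars dtype; a minimal sketch (assuming a Polars version where `pl.Struct` accepts a field mapping):

```python
import polars as pl

# Dtype of the "sgt_result" struct column as declared in sgt_transform_output:
sgt_result_dtype = pl.Struct({
    "sequence_id": pl.String,
    "ngram_keys": pl.List(pl.String),
    "value": pl.List(pl.Float64),
})
print(sgt_result_dtype)
```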
@@ -141,5 +143,7 @@ fn sgt_transform(inputs: &[Series], kwargs: SgtTransformKwargs) -> PolarsResult<Series> {
         kwargs.alpha,
         kwargs.beta,
         kwargs.deltatime.as_deref(),
+        kwargs.sequence_id_name.as_deref(),
+        kwargs.state_name.as_deref(),
     )
 }
polars_sgt-0.2.0/src/sgt_transform.rs
@@ -0,0 +1,393 @@
+// High-performance SGT implementation optimized for 100M+ records
+// Uses group-based indexing (O(n)) and parallel processing with Rayon
+use polars::prelude::*;
+use rayon::prelude::*;
+use std::collections::HashMap;
+
+/// Time penalty modes for SGT
+#[derive(Debug, Clone, Copy)]
+pub enum TimePenalty {
+    Inverse,
+    Exponential,
+    Linear,
+    Power,
+    None,
+}
+
+impl TimePenalty {
+    pub fn from_str(s: &str) -> PolarsResult<Self> {
+        match s {
+            "inverse" => Ok(TimePenalty::Inverse),
+            "exponential" => Ok(TimePenalty::Exponential),
+            "linear" => Ok(TimePenalty::Linear),
+            "power" => Ok(TimePenalty::Power),
+            "none" => Ok(TimePenalty::None),
+            _ => polars_bail!(InvalidOperation: "Unknown time_penalty: {}", s),
+        }
+    }
+
+    #[inline(always)]
+    pub fn apply(&self, time_diff: f64, alpha: f64, beta: f64) -> f64 {
+        if time_diff == 0.0 {
+            return 1.0;
+        }
+        match self {
+            TimePenalty::Inverse => alpha / time_diff,
+            TimePenalty::Exponential => (-alpha * time_diff).exp(),
+            TimePenalty::Linear => (1.0 - alpha * time_diff).max(0.0),
+            TimePenalty::Power => 1.0 / time_diff.powf(beta),
+            TimePenalty::None => 1.0,
+        }
+    }
+}
+
+/// Normalization modes for SGT
+#[derive(Debug, Clone, Copy)]
+pub enum NormMode {
+    L1,
+    L2,
+    None,
+}
+
+impl NormMode {
+    pub fn from_str(s: &str) -> PolarsResult<Self> {
+        match s {
+            "l1" => Ok(NormMode::L1),
+            "l2" => Ok(NormMode::L2),
+            "none" => Ok(NormMode::None),
+            _ => polars_bail!(InvalidOperation: "Unknown mode: {}", s),
+        }
+    }
+
+    #[inline(always)]
+    pub fn normalize(&self, weights: &mut Vec<f64>) {
+        match self {
+            NormMode::L1 => {
+                let sum: f64 = weights.iter().sum();
+                if sum > 0.0 {
+                    for weight in weights.iter_mut() {
+                        *weight /= sum;
+                    }
+                }
+            }
+            NormMode::L2 => {
+                let sum_sq: f64 = weights.iter().map(|w| w * w).sum();
+                if sum_sq > 0.0 {
+                    let norm = sum_sq.sqrt();
+                    for weight in weights.iter_mut() {
+                        *weight /= norm;
+                    }
+                }
+            }
+            NormMode::None => {}
+        }
+    }
+}
+
+/// Convert deltatime string to seconds multiplier
+#[inline(always)]
+fn deltatime_to_seconds(deltatime: Option<&str>) -> PolarsResult<f64> {
+    match deltatime {
+        None => Ok(1.0),
+        Some("s") => Ok(1.0),
+        Some("m") => Ok(60.0),
+        Some("h") => Ok(3600.0),
+        Some("d") => Ok(86400.0),
+        Some("w") => Ok(604800.0),
+        Some("month") => Ok(2629800.0),
+        Some("q") => Ok(7889400.0),
+        Some("y") => Ok(31557600.0),
+        Some(other) => polars_bail!(InvalidOperation: "Unknown deltatime: {}", other),
+    }
+}
+
+/// Extract time values as f64 for a batch of indices
+#[inline]
+fn extract_time_values(
+    series: &Series,
+    indices: &[usize],
+    deltatime: Option<&str>,
+) -> PolarsResult<Vec<Option<f64>>> {
+    let divisor = deltatime_to_seconds(deltatime)?;
+
+    match series.dtype() {
+        DataType::Datetime(time_unit, _) => {
+            let ca = series.datetime()?;
+            let time_unit_divisor = match time_unit {
+                TimeUnit::Nanoseconds => 1_000_000_000.0,
+                TimeUnit::Microseconds => 1_000_000.0,
+                TimeUnit::Milliseconds => 1_000.0,
+            };
+            Ok(indices
+                .iter()
+                .map(|&i| unsafe { ca.phys.get_unchecked(i) }.map(|v| v as f64 / time_unit_divisor / divisor))
+                .collect())
+        }
+        DataType::Date => {
+            let ca = series.date()?;
+            let date_divisor = divisor / 86400.0;
+            Ok(indices
+                .iter()
+                .map(|&i| unsafe { ca.phys.get_unchecked(i) }.map(|v| v as f64 / date_divisor))
+                .collect())
+        }
+        DataType::Duration(time_unit) => {
+            let ca = series.duration()?;
+            let time_unit_divisor = match time_unit {
+                TimeUnit::Nanoseconds => 1_000_000_000.0,
+                TimeUnit::Microseconds => 1_000_000.0,
+                TimeUnit::Milliseconds => 1_000.0,
+            };
+            Ok(indices
+                .iter()
+                .map(|&i| unsafe { ca.phys.get_unchecked(i) }.map(|v| v as f64 / time_unit_divisor / divisor))
+                .collect())
+        }
+        _ => {
+            let ca = series.cast(&DataType::Float64)?;
+            let f64_ca = ca.f64()?;
+            Ok(indices.iter().map(|&i| f64_ca.get(i)).collect())
+        }
+    }
+}
+
+/// Result for a single sequence
+struct SequenceResult {
+    seq_id: String,
+    ngram_keys: Vec<String>,
+    ngram_values: Vec<f64>,
+}
+
+/// Generate n-grams with weights from a sequence (optimized version)
+#[inline]
+fn generate_ngrams_fast(
+    states: &[&str],
+    time_values: &[Option<f64>],
+    kappa: usize,
+    time_penalty: TimePenalty,
+    alpha: f64,
+    beta: f64,
+) -> (Vec<String>, Vec<f64>) {
+    if states.is_empty() {
+        return (Vec::new(), Vec::new());
+    }
+
+    // Estimate capacity: n-grams up to kappa for sequence of length L
+    // Total n-grams ≈ L + (L-1) + ... + (L-kappa+1)
+    let estimated_capacity = states.len() * kappa.min(states.len());
+    let mut ngram_weights: HashMap<String, f64> = HashMap::with_capacity(estimated_capacity);
+
+    // Generate n-grams up to kappa size
+    let max_n = kappa.min(states.len());
+    for n in 1..=max_n {
+        for i in 0..=(states.len() - n) {
+            // Build n-gram key efficiently
+            let ngram_key = if n == 1 {
+                states[i].to_string()
+            } else {
+                states[i..i + n].join(" -> ")
+            };
+
+            // Calculate weight based on time difference
+            let weight = if n > 1 && i + n - 1 < time_values.len() {
+                if let (Some(curr_time), Some(prev_time)) =
+                    (time_values[i + n - 1], time_values[i + n - 2])
+                {
+                    let time_diff = (curr_time - prev_time).abs();
+                    time_penalty.apply(time_diff, alpha, beta)
+                } else {
+                    1.0
+                }
+            } else {
+                1.0
+            };
+
+            *ngram_weights.entry(ngram_key).or_insert(0.0) += weight;
+        }
+    }
+
+    // Convert to sorted vectors
+    let mut keys: Vec<String> = ngram_weights.keys().cloned().collect();
+    keys.sort_unstable();
+    let values: Vec<f64> = keys.iter().map(|k| ngram_weights[k]).collect();
+
+    (keys, values)
+}
+
+/// Process a single sequence group
+#[inline]
+fn process_sequence(
+    seq_id: &str,
+    indices: &[usize],
+    states_ca: &StringChunked,
+    time_series: Option<&Series>,
+    kappa: usize,
+    length_sensitive: bool,
+    time_penalty: TimePenalty,
+    norm_mode: NormMode,
+    alpha: f64,
+    beta: f64,
+    deltatime: Option<&str>,
+) -> PolarsResult<Option<SequenceResult>> {
+    // Extract states for this sequence using direct index access
+    let states: Vec<&str> = indices
+        .iter()
+        .filter_map(|&i| states_ca.get(i))
+        .collect();
+
+    if states.is_empty() {
+        return Ok(None);
+    }
+
+    // Extract time values
+    let time_values = if let Some(ts) = time_series {
+        extract_time_values(ts, indices, deltatime)?
+    } else {
+        // Use index positions as time
+        indices.iter().map(|&i| Some(i as f64)).collect()
+    };
+
+    // Generate n-grams with weights
+    let (keys, mut values) = generate_ngrams_fast(
+        &states,
+        &time_values,
+        kappa,
+        time_penalty,
+        alpha,
+        beta,
+    );
+
+    // Apply length normalization if requested
+    if length_sensitive && states.len() > 1 {
+        let seq_len = states.len() as f64;
+        for weight in values.iter_mut() {
+            *weight /= seq_len;
+        }
+    }
+
+    // Apply normalization mode
+    norm_mode.normalize(&mut values);
+
+    Ok(Some(SequenceResult {
+        seq_id: seq_id.to_string(),
+        ngram_keys: keys,
+        ngram_values: values,
+    }))
+}
+
+/// High-performance SGT implementation using group-based indexing and parallel processing
+#[allow(clippy::too_many_arguments)]
+pub fn impl_sgt_transform(
+    inputs: &[Series],
+    kappa: i64,
+    length_sensitive: bool,
+    mode: &str,
+    time_penalty: &str,
+    alpha: f64,
+    beta: f64,
+    deltatime: Option<&str>,
+    sequence_id_name: Option<&str>,
+    state_name: Option<&str>,
+) -> PolarsResult<Series> {
+    if inputs.len() < 2 {
+        polars_bail!(InvalidOperation: "sgt_transform requires at least sequence_id and state columns");
+    }
+
+    let sequence_ids = inputs[0].cast(&DataType::String)?;
+    let states_series = &inputs[1];
+    let time_series = if inputs.len() > 2 {
+        Some(&inputs[2])
+    } else {
+        None
+    };
+
+    let kappa = kappa as usize;
+    let time_penalty_mode = TimePenalty::from_str(time_penalty)?;
+    let norm_mode = NormMode::from_str(mode)?;
+
+    let seq_ids_ca = sequence_ids.str()?;
+    let states_ca = states_series.str()?;
+
+    // OPTIMIZATION 1: Build group index in O(n) - single pass
+    // This replaces the O(n*m) nested loop
+    let mut groups: HashMap<&str, Vec<usize>> = HashMap::new();
+    for (idx, seq_id) in seq_ids_ca.iter().enumerate() {
+        if let Some(id) = seq_id {
+            groups.entry(id).or_default().push(idx);
+        }
+    }
+
+    // OPTIMIZATION 2: Process groups in parallel with Rayon
+    let results: Vec<PolarsResult<Option<SequenceResult>>> = groups
+        .par_iter()
+        .map(|(seq_id, indices)| {
+            process_sequence(
+                seq_id,
+                indices,
+                states_ca,
+                time_series,
+                kappa,
+                length_sensitive,
+                time_penalty_mode,
+                norm_mode,
+                alpha,
+                beta,
+                deltatime,
+            )
+        })
+        .collect();
+
+    // Collect successful results
+    let mut result_seq_ids: Vec<String> = Vec::with_capacity(groups.len());
+    let mut result_ngram_keys_list: Vec<Series> = Vec::with_capacity(groups.len());
+    let mut result_ngram_values_list: Vec<Series> = Vec::with_capacity(groups.len());
+
+    for result in results {
+        if let Some(seq_result) = result? {
+            result_seq_ids.push(seq_result.seq_id);
+            result_ngram_keys_list.push(
+                StringChunked::from_iter(seq_result.ngram_keys.iter().map(|s| Some(s.as_str())))
+                    .into_series(),
+            );
+            result_ngram_values_list.push(
+                Float64Chunked::from_vec(PlSmallStr::EMPTY, seq_result.ngram_values).into_series(),
+            );
+        }
+    }
+
+    // Sort by sequence ID for deterministic output
+    let mut indexed: Vec<(usize, &String)> = result_seq_ids.iter().enumerate().collect();
+    indexed.sort_by(|a, b| a.1.cmp(b.1));
+
+    let sorted_seq_ids: Vec<String> = indexed.iter().map(|(i, _)| result_seq_ids[*i].clone()).collect();
+    let sorted_keys: Vec<Series> = indexed.iter().map(|(i, _)| result_ngram_keys_list[*i].clone()).collect();
+    let sorted_values: Vec<Series> = indexed.iter().map(|(i, _)| result_ngram_values_list[*i].clone()).collect();
+
+    // Use parameter names for struct fields (fallback to defaults)
+    let seq_field_name = sequence_id_name.unwrap_or("sequence_id");
+    let _state_field_name = state_name.unwrap_or("state"); // Reserved for future use
+
+    // Build result struct
+    let mut seq_id_ca = StringChunked::from_iter(sorted_seq_ids.iter().map(|s| Some(s.as_str())));
+    seq_id_ca.rename(PlSmallStr::from_str(seq_field_name));
+    let seq_id_series = seq_id_ca.into_series();
+
+    // Convert to list series
+    let ngram_keys_dtype = DataType::List(Box::new(DataType::String));
+    let ngram_keys_series = Series::new(PlSmallStr::from_str("ngram_keys"), sorted_keys)
+        .cast(&ngram_keys_dtype)?;
+
+    // Renamed from ngram_values to value
+    let ngram_values_dtype = DataType::List(Box::new(DataType::Float64));
+    let ngram_values_series = Series::new(PlSmallStr::from_str("value"), sorted_values)
+        .cast(&ngram_values_dtype)?;
+
+    // Create struct
+    let struct_fields = [seq_id_series, ngram_keys_series, ngram_values_series];
+    Ok(StructChunked::from_series(
+        PlSmallStr::from_str("sgt_result"),
+        sorted_seq_ids.len(),
+        struct_fields.iter(),
+    )?
+    .into_series())
+}
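For reference, the weights computed in `TimePenalty::apply` above (with Δt the gap between adjacent states after the `deltatime` rescaling, and w = 1 whenever Δt = 0) and the `NormMode` normalizations correspond to:

```latex
w(\Delta t) =
\begin{cases}
  \alpha / \Delta t              & \text{inverse} \\
  e^{-\alpha \Delta t}           & \text{exponential} \\
  \max(0,\, 1 - \alpha \Delta t) & \text{linear} \\
  \Delta t^{-\beta}              & \text{power} \\
  1                              & \text{none}
\end{cases}
\qquad
\hat w_i = \frac{w_i}{\sum_j w_j} \ (\texttt{l1}),
\qquad
\hat w_i = \frac{w_i}{\sqrt{\sum_j w_j^2}} \ (\texttt{l2})
```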
{polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_sgt_transform.py
@@ -113,7 +113,7 @@ def test_sgt_time_penalty_exponential() -> None:
     )
 
     weights = result.select(
-        pl.col("sgt").struct.field("ngram_values")
+        pl.col("sgt").struct.field("value")
     ).to_series().to_list()[0]
 
     # Weights should be positive
@@ -187,7 +187,7 @@ def test_sgt_time_penalty_none() -> None:
 
     # With no penalty, all weights should be integer counts
     weights = result.select(
-        pl.col("sgt").struct.field("ngram_values")
+        pl.col("sgt").struct.field("value")
     ).to_series().to_list()[0]
 
     assert all(w > 0 for w in weights)
@@ -210,7 +210,7 @@ def test_sgt_l1_normalization() -> None:
     )
 
     weights = result.select(
-        pl.col("sgt").struct.field("ngram_values")
+        pl.col("sgt").struct.field("value")
     ).to_series().to_list()[0]
 
     # L1 normalization: sum should be 1.0
@@ -234,7 +234,7 @@ def test_sgt_l2_normalization() -> None:
     )
 
     weights = result.select(
-        pl.col("sgt").struct.field("ngram_values")
+        pl.col("sgt").struct.field("value")
     ).to_series().to_list()[0]
 
     # L2 normalization: sum of squares should be 1.0
@@ -260,7 +260,7 @@ def test_sgt_length_sensitive() -> None:
     )
 
     weights = result.select(
-        pl.col("sgt").struct.field("ngram_values")
+        pl.col("sgt").struct.field("value")
     ).to_series().to_list()[0]
 
     # With length normalization, weights should be divided by sequence length
@@ -341,7 +341,7 @@ def test_sgt_struct_output() -> None:
     expanded = result.select([
         pl.col("sgt").struct.field("sequence_id").alias("seq_id"),
         pl.col("sgt").struct.field("ngram_keys").alias("keys"),
-        pl.col("sgt").struct.field("ngram_values").alias("values"),
+        pl.col("sgt").struct.field("value").alias("values"),
     ])
 
     assert "seq_id" in expanded.columns
@@ -372,7 +372,7 @@ def test_sgt_explode_pattern() -> None:
     exploded = result.select([
         pl.col("sgt").struct.field("sequence_id"),
        pl.col("sgt").struct.field("ngram_keys").alias("ngram"),
-        pl.col("sgt").struct.field("ngram_values").alias("weight"),
+        pl.col("sgt").struct.field("value").alias("weight"),
     ]).explode(["ngram", "weight"])
 
     assert exploded.shape[0] > 0
polars_sgt-0.1.0/src/sgt_transform.rs
@@ -1,304 +0,0 @@
-// Simplified SGT implementation that actually compiles
-// This implementation works correctly with Polars API patterns
-use polars::prelude::*;
-use std::collections::HashMap;
-
-/// Time penalty modes for SGT
-#[derive(Debug, Clone)]
-pub enum TimePenalty {
-    Inverse,
-    Exponential,
-    Linear,
-    Power,
-    None,
-}
-
-impl TimePenalty {
-    pub fn from_str(s: &str) -> PolarsResult<Self> {
-        match s {
-            "inverse" => Ok(TimePenalty::Inverse),
-            "exponential" => Ok(TimePenalty::Exponential),
-            "linear" => Ok(TimePenalty::Linear),
-            "power" => Ok(TimePenalty::Power),
-            "none" => Ok(TimePenalty::None),
-            _ => polars_bail!(InvalidOperation: "Unknown time_penalty: {}", s),
-        }
-    }
-
-    pub fn apply(&self, time_diff: f64, alpha: f64, beta: f64) -> f64 {
-        if time_diff == 0.0 {
-            return 1.0;
-        }
-        match self {
-            TimePenalty::Inverse => alpha / time_diff,
-            TimePenalty::Exponential => (-alpha * time_diff).exp(),
-            TimePenalty::Linear => (1.0 - alpha * time_diff).max(0.0),
-            TimePenalty::Power => 1.0 / time_diff.powf(beta),
-            TimePenalty::None => 1.0,
-        }
-    }
-}
-
-/// Normalization modes for SGT
-#[derive(Debug, Clone)]
-pub enum NormMode {
-    L1,
-    L2,
-    None,
-}
-
-impl NormMode {
-    pub fn from_str(s: &str) -> PolarsResult<Self> {
-        match s {
-            "l1" => Ok(NormMode::L1),
-            "l2" => Ok(NormMode::L2),
-            "none" => Ok(NormMode::None),
-            _ => polars_bail!(InvalidOperation: "Unknown mode: {}", s),
-        }
-    }
-
-    pub fn normalize(&self, weights: &mut HashMap<String, f64>) {
-        match self {
-            NormMode::L1 => {
-                let sum: f64 = weights.values().sum();
-                if sum > 0.0 {
-                    for weight in weights.values_mut() {
-                        *weight /= sum;
-                    }
-                }
-            }
-            NormMode::L2 => {
-                let sum_sq: f64 = weights.values().map(|w| w * w).sum();
-                if sum_sq > 0.0 {
-                    let norm = sum_sq.sqrt();
-                    for weight in weights.values_mut() {
-                        *weight /= norm;
-                    }
-                }
-            }
-            NormMode::None => {}
-        }
-    }
-}
-
-/// Convert deltatime string to seconds multiplier
-fn deltatime_to_seconds(deltatime: Option<&str>) -> PolarsResult<f64> {
-    match deltatime {
-        None => Ok(1.0),
-        Some("s") => Ok(1.0),
-        Some("m") => Ok(60.0),
-        Some("h") => Ok(3600.0),
-        Some("d") => Ok(86400.0),
-        Some("w") => Ok(604800.0),
-        Some("month") => Ok(2629800.0), // 30.44 days
-        Some("q") => Ok(7889400.0),     // 91.31 days
-        Some("y") => Ok(31557600.0),    // 365.25 days
-        Some(other) => polars_bail!(InvalidOperation: "Unknown deltatime: {}", other),
-    }
-}
-
-/// Extract time value as f64
-fn get_time_value(series: &Series, idx: usize, deltatime: Option<&str>) -> PolarsResult<Option<f64>> {
-    match series.dtype() {
-        DataType::Datetime(time_unit, _) => {
-            let ca = series.datetime()?;
-            let divisor = deltatime_to_seconds(deltatime)?;
-            let time_unit_divisor = match time_unit {
-                TimeUnit::Nanoseconds => 1_000_000_000.0,
-                TimeUnit::Microseconds => 1_000_000.0,
-                TimeUnit::Milliseconds => 1_000.0,
-            };
-            Ok(unsafe { ca.phys.get_unchecked(idx) }.map(|v| v as f64 / time_unit_divisor / divisor))
-        }
-        DataType::Date => {
-            let ca = series.date()?;
-            let divisor = deltatime_to_seconds(deltatime)? / 86400.0;
-            Ok(unsafe { ca.phys.get_unchecked(idx) }.map(|v| v as f64 / divisor))
-        }
-        DataType::Duration(time_unit) => {
-            let ca = series.duration()?;
-            let divisor = deltatime_to_seconds(deltatime)?;
-            let time_unit_divisor = match time_unit {
-                TimeUnit::Nanoseconds => 1_000_000_000.0,
-                TimeUnit::Microseconds => 1_000_000.0,
-                TimeUnit::Milliseconds => 1_000.0,
-            };
-            Ok(unsafe { ca.phys.get_unchecked(idx) }.map(|v| v as f64 / time_unit_divisor / divisor))
-        }
-        _ => {
-            let ca = series.cast(&DataType::Float64)?;
-            Ok(ca.f64()?.get(idx))
-        }
-    }
-}
-
-/// Generate n-grams with weights from a sequence
-fn generate_ngrams(
-    states: &[String],
-    time_values: &[Option<f64>],
-    kappa: usize,
-    time_penalty: &TimePenalty,
-    alpha: f64,
-    beta: f64,
-) -> HashMap<String, f64> {
-    let mut ngram_weights: HashMap<String, f64> = HashMap::new();
-
-    if states.is_empty() {
-        return ngram_weights;
-    }
-
-    // Generate n-grams up to kappa size
-    for n in 1..=kappa.min(states.len()) {
-        for i in 0..=(states.len() - n) {
-            let ngram: Vec<&str> = states[i..i + n].iter().map(|s| s.as_str()).collect();
-            let ngram_key = ngram.join(" -> ");
-
-            // Calculate weight based on time difference
-            let weight = if n > 1 && i + n - 1 < time_values.len() {
-                if let (Some(curr_time), Some(prev_time)) = (time_values[i + n - 1], time_values[i + n - 2]) {
-                    let time_diff = (curr_time - prev_time).abs();
-                    time_penalty.apply(time_diff, alpha, beta)
-                } else {
-                    1.0
-                }
-            } else {
-                1.0
-            };
-
-            *ngram_weights.entry(ngram_key).or_insert(0.0) += weight;
-        }
-    }
-
-    ngram_weights
-}
-
-/// Main SGT implementation using simple iteration
-#[allow(clippy::too_many_arguments)]
-pub fn impl_sgt_transform(
-    inputs: &[Series],
-    kappa: i64,
-    length_sensitive: bool,
-    mode: &str,
-    time_penalty: &str,
-    alpha: f64,
-    beta: f64,
-    deltatime: Option<&str>,
-) -> PolarsResult<Series> {
-    if inputs.len() < 2 {
-        polars_bail!(InvalidOperation: "sgt_transform requires at least sequence_id and state columns");
-    }
-
-    let sequence_ids = inputs[0].cast(&DataType::String)?;
-    let states_series = &inputs[1];
-    let time_series = if inputs.len() > 2 {
-        Some(&inputs[2])
-    } else {
-        None
-    };
-
-    let kappa = kappa as usize;
-    let time_penalty_mode = TimePenalty::from_str(time_penalty)?;
-    let norm_mode = NormMode::from_str(mode)?;
-
-    // Get unique sequence IDs
-    let unique_ids: StringChunked = sequence_ids.str()?.unique()?.sort(Default::default());
-
-    let mut result_seq_ids: Vec<String> = Vec::new();
-    let mut result_ngram_keys_list: Vec<Series> = Vec::new();
-    let mut result_ngram_values_list: Vec<Series> = Vec::new();
-
-    let seq_ids_ca = sequence_ids.str()?;
-    let states_ca = states_series.str()?;
-
-    // Process each unique sequence ID
-    for idx in 0..unique_ids.len() {
-        let seq_id: &str = match unique_ids.get(idx) {
-            Some(id) => id,
-            None => continue,
-        };
-
-        // Find all rows matching this sequence ID
-        let mask: BooleanChunked = seq_ids_ca.equal(seq_id);
-
-        // Extract states for this sequence
-        let mut sequence_states = Vec::new();
-        let mut time_values = Vec::new();
-
-        for i in 0..mask.len() {
-            if mask.get(i).unwrap_or(false) {
-                if let Some(state) = states_ca.get(i) {
-                    sequence_states.push(state.to_string());
-                    if let Some(ts) = time_series {
-                        time_values.push(get_time_value(ts, i, deltatime)?);
-                    } else {
-                        time_values.push(Some(i as f64));
-                    }
-                }
-            }
-        }
-
-        if sequence_states.is_empty() {
-            continue;
-        }
-
-        // Generate n-grams with weights
-        let mut ngram_weights = generate_ngrams(
-            &sequence_states,
-            &time_values,
-            kappa,
-            &time_penalty_mode,
-            alpha,
-            beta,
-        );
-
-        // Apply length normalization if requested
-        if length_sensitive && sequence_states.len() > 1 {
-            let seq_len = sequence_states.len() as f64;
-            for weight in ngram_weights.values_mut() {
-                *weight /= seq_len;
-            }
-        }
-
-        // Apply normalization mode
-        norm_mode.normalize(&mut ngram_weights);
-
-        // Convert to sorted vectors
-        let mut keys: Vec<String> = ngram_weights.keys().cloned().collect();
-        keys.sort();
-        let values: Vec<f64> = keys.iter().map(|k| ngram_weights[k]).collect();
-
-        result_seq_ids.push(seq_id.to_string());
-        result_ngram_keys_list.push(
-            StringChunked::from_iter(keys.iter().map(|s| Some(s.as_str()))).into_series()
-        );
-        result_ngram_values_list.push(
-            Float64Chunked::from_vec(PlSmallStr::EMPTY, values).into_series()
-        );
-    }
-
-    // Build result struct
-    let mut seq_id_ca = StringChunked::from_iter(result_seq_ids.iter().map(|s| Some(s.as_str())));
-    seq_id_ca.rename(PlSmallStr::from_str("sequence_id"));
-    let seq_id_series = seq_id_ca.into_series();
-
-    // Convert to list series
-    let ngram_keys_dtype = DataType::List(Box::new(DataType::String));
-    let ngram_keys_series = Series::new(
-        PlSmallStr::from_str("ngram_keys"),
-        result_ngram_keys_list
-    ).cast(&ngram_keys_dtype)?;
-
-    let ngram_values_dtype = DataType::List(Box::new(DataType::Float64));
-    let ngram_values_series = Series::new(
-        PlSmallStr::from_str("ngram_values"),
-        result_ngram_values_list
-    ).cast(&ngram_values_dtype)?;
-
-    // Create struct
-    let struct_fields = [seq_id_series, ngram_keys_series, ngram_values_series];
-    Ok(StructChunked::from_series(
-        PlSmallStr::from_str("sgt_result"),
-        result_seq_ids.len(),
-        struct_fields.iter()
-    )?.into_series())
-}
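The removed implementation scanned the full frame once per unique sequence id (the boolean `mask` loop above), while the replacement builds its group index in a single pass and hands the groups to Rayon. Roughly, for n rows and m unique sequence ids:

```latex
T_{0.1.0}(n, m) = \Theta(n \cdot m)
\quad\longrightarrow\quad
T_{0.2.0}(n, m) = \Theta(n) + \text{per-group n-gram work (parallel across groups)}
```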