polars-sgt 0.1.0__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. polars_sgt-0.2.0/CHANGELOG.md +19 -0
  2. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/Cargo.lock +2 -1
  3. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/Cargo.toml +2 -1
  4. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/PKG-INFO +23 -3
  5. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/README.md +22 -2
  6. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/functions.py +4 -2
  7. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/expressions.rs +5 -1
  8. polars_sgt-0.2.0/src/sgt_transform.rs +393 -0
  9. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_sgt_transform.py +7 -7
  10. polars_sgt-0.1.0/src/sgt_transform.rs +0 -304
  11. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/.github/workflows/CI.yml +0 -0
  12. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/.gitignore +0 -0
  13. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/.python-version +0 -0
  14. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/.readthedocs.yaml +0 -0
  15. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/CODE_OF_CONDUCT.md +0 -0
  16. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/LICENSE +0 -0
  17. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/Makefile +0 -0
  18. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/assets/.DS_Store +0 -0
  19. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/assets/polars-business.png +0 -0
  20. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/bump_version.py +0 -0
  21. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/API.rst +0 -0
  22. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/Makefile +0 -0
  23. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/conf.py +0 -0
  24. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/index.rst +0 -0
  25. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/installation.rst +0 -0
  26. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/docs/requirements-docs.txt +0 -0
  27. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/dprint.json +0 -0
  28. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/licenses/NUMPY_LICENSE.txt +0 -0
  29. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/licenses/PANDAS_LICENSE.txt +0 -0
  30. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/.mypy.ini +0 -0
  31. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/__init__.py +0 -0
  32. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/_internal.pyi +0 -0
  33. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/namespace.py +0 -0
  34. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/py.typed +0 -0
  35. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/ranges.py +0 -0
  36. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/typing.py +0 -0
  37. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/polars_sgt/utils.py +0 -0
  38. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/pyproject.toml +0 -0
  39. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/requirements.txt +0 -0
  40. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/rust-toolchain.toml +0 -0
  41. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/arg_previous_greater.rs +0 -0
  42. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/format_localized.rs +0 -0
  43. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/lib.rs +0 -0
  44. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/month_delta.rs +0 -0
  45. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/timezone.rs +0 -0
  46. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/src/to_julian.rs +0 -0
  47. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/__init__.py +0 -0
  48. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/ceil_test.py +0 -0
  49. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/julian_date_test.py +0 -0
  50. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_benchmark.py +0 -0
  51. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_date_range.py +0 -0
  52. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_format_localized.py +0 -0
  53. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_is_busday.py +0 -0
  54. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_month_delta.py +0 -0
  55. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/tests/test_timezone.py +0 -0
  56. {polars_sgt-0.1.0 → polars_sgt-0.2.0}/uv.lock +0 -0
@@ -0,0 +1,19 @@
1
+ # Changelog
2
+
3
+ All notable changes to this project will be documented in this file.
4
+
5
+ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6
+ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7
+
8
+ ## [0.2.0] - 2026-02-02
9
+
10
+ ### Added
11
+ - Parallel processing support with `rayon` for SGT transform.
12
+ - Support for custom output struct field names via `sequence_id_name` and `state_name` parameters.
13
+
14
+ ### Changed
15
+ - **Major Performance Optimization**: Rewrote SGT transform to use O(n) group-based indexing instead of O(n*m) scanning. Throughput increased to ~1.4M+ records/second.
16
+ - **Struct Field Rename (BREAKING)**: Renamed `ngram_values` field in the output struct to `value` for consistency with current Polars version and parameter names.
17
+
18
+ ### Fixed
19
+ - Performance bottleneck on large datasets (10M+ records).
@@ -2010,7 +2010,7 @@ dependencies = [
2010
2010
 
2011
2011
  [[package]]
2012
2012
  name = "polars_sgt"
2013
- version = "0.1.0"
2013
+ version = "0.2.0"
2014
2014
  dependencies = [
2015
2015
  "chrono",
2016
2016
  "chrono-tz",
@@ -2019,6 +2019,7 @@ dependencies = [
2019
2019
  "polars-ops",
2020
2020
  "pyo3",
2021
2021
  "pyo3-polars",
2022
+ "rayon",
2022
2023
  "serde",
2023
2024
  ]
2024
2025
 
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "polars_sgt"
3
- version = "0.1.0"
3
+ version = "0.2.0"
4
4
  edition = "2021"
5
5
  authors = ["Zedd <lytran14789@gmail.com>", "Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>"]
6
6
  readme = "README.md"
@@ -19,4 +19,5 @@ chrono-tz = "0.10.4"
19
19
  polars = { version = "0.52.0", features = ["strings", "timezones"]}
20
20
  polars-ops = { version = "0.52.0", default-features = false }
21
21
  polars-arrow = { version = "0.52.0", default-features = false }
22
+ rayon = "1.10"
22
23
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: polars-sgt
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python :: Implementation :: CPython
6
6
  Classifier: Programming Language :: Python :: Implementation :: PyPy
@@ -91,10 +91,30 @@ result = df.select(
91
91
  features = result.select([
92
92
  pl.col("sgt_features").struct.field("sequence_id"),
93
93
  pl.col("sgt_features").struct.field("ngram_keys").alias("ngrams"),
94
- pl.col("sgt_features").struct.field("ngram_values").alias("weights"),
94
+ pl.col("sgt_features").struct.field("value").alias("weights"),
95
95
  ]).explode(["ngrams", "weights"])
96
96
 
97
97
  print(features)
98
+
99
+ #OR
100
+ result = df.select(
101
+ sgt.sgt_transform(
102
+ "session_id",
103
+ "event",
104
+ time_col="time",
105
+ deltatime="m", # minutes
106
+ kappa=3, # trigrams
107
+ time_penalty="inverse",
108
+ mode="l2",
109
+ alpha=0.5
110
+ ).alias("struct_type")
111
+ )
112
+ out = (
113
+ result
114
+ .unnest("struct_type")
115
+ .explode(["ngram_keys", "value"])
116
+ .filter(pl.col("ngram_keys").str.split("->").list.len() > 0)
117
+ )
98
118
  ```
99
119
 
100
120
  ### With DateTime Columns
@@ -180,7 +200,7 @@ result = (
180
200
  Returns a Struct with three fields:
181
201
  - `sequence_id`: Original sequence identifier
182
202
  - `ngram_keys`: List of n-gram strings (e.g., "login -> view -> purchase")
183
- - `ngram_values`: List of corresponding weights
203
+ - `value`: List of corresponding weights
184
204
 
185
205
  ## Additional DateTime Utilities
186
206
 
@@ -72,10 +72,30 @@ result = df.select(
72
72
  features = result.select([
73
73
  pl.col("sgt_features").struct.field("sequence_id"),
74
74
  pl.col("sgt_features").struct.field("ngram_keys").alias("ngrams"),
75
- pl.col("sgt_features").struct.field("ngram_values").alias("weights"),
75
+ pl.col("sgt_features").struct.field("value").alias("weights"),
76
76
  ]).explode(["ngrams", "weights"])
77
77
 
78
78
  print(features)
79
+
80
+ #OR
81
+ result = df.select(
82
+ sgt.sgt_transform(
83
+ "session_id",
84
+ "event",
85
+ time_col="time",
86
+ deltatime="m", # minutes
87
+ kappa=3, # trigrams
88
+ time_penalty="inverse",
89
+ mode="l2",
90
+ alpha=0.5
91
+ ).alias("struct_type")
92
+ )
93
+ out = (
94
+ result
95
+ .unnest("struct_type")
96
+ .explode(["ngram_keys", "value"])
97
+ .filter(pl.col("ngram_keys").str.split("->").list.len() > 0)
98
+ )
79
99
  ```
80
100
 
81
101
  ### With DateTime Columns
@@ -161,7 +181,7 @@ result = (
161
181
  Returns a Struct with three fields:
162
182
  - `sequence_id`: Original sequence identifier
163
183
  - `ngram_keys`: List of n-gram strings (e.g., "login -> view -> purchase")
164
- - `ngram_values`: List of corresponding weights
184
+ - `value`: List of corresponding weights
165
185
 
166
186
  ## Additional DateTime Utilities
167
187
 
@@ -740,7 +740,7 @@ def sgt_transform(
740
740
  Struct expression containing:
741
741
  - sequence_id: Original sequence identifier
742
742
  - ngram_keys: List of n-gram strings
743
- - ngram_values: List of corresponding weights
743
+ - value: List of corresponding weights
744
744
 
745
745
  Examples
746
746
  --------
@@ -821,7 +821,7 @@ def sgt_transform(
821
821
  >>> df_features = result.select([
822
822
  ... pl.col("sgt_result").struct.field("sequence_id"),
823
823
  ... pl.col("sgt_result").struct.field("ngram_keys").alias("ngrams"),
824
- ... pl.col("sgt_result").struct.field("ngram_values").alias("weights"),
824
+ ... pl.col("sgt_result").struct.field("value").alias("weights"),
825
825
  ... ]).explode(["ngrams", "weights"])
826
826
 
827
827
  Notes
@@ -855,5 +855,7 @@ def sgt_transform(
855
855
  "alpha": alpha,
856
856
  "beta": beta,
857
857
  "deltatime": deltatime,
858
+ "sequence_id_name": None,
859
+ "state_name": None,
858
860
  },
859
861
  )
@@ -30,6 +30,8 @@ pub struct SgtTransformKwargs {
30
30
  alpha: f64,
31
31
  beta: f64,
32
32
  deltatime: Option<String>,
33
+ sequence_id_name: Option<String>,
34
+ state_name: Option<String>,
33
35
  }
34
36
 
35
37
  pub fn to_local_datetime_output(input_fields: &[Field]) -> PolarsResult<Field> {
@@ -122,7 +124,7 @@ fn sgt_transform_output(_input_fields: &[Field]) -> PolarsResult<Field> {
122
124
  let fields = vec![
123
125
  Field::new(PlSmallStr::from_str("sequence_id"), DataType::String),
124
126
  Field::new(PlSmallStr::from_str("ngram_keys"), DataType::List(Box::new(DataType::String))),
125
- Field::new(PlSmallStr::from_str("ngram_values"), DataType::List(Box::new(DataType::Float64))),
127
+ Field::new(PlSmallStr::from_str("value"), DataType::List(Box::new(DataType::Float64))),
126
128
  ];
127
129
  Ok(Field::new(
128
130
  PlSmallStr::from_str("sgt_result"),
@@ -141,5 +143,7 @@ fn sgt_transform(inputs: &[Series], kwargs: SgtTransformKwargs) -> PolarsResult<
141
143
  kwargs.alpha,
142
144
  kwargs.beta,
143
145
  kwargs.deltatime.as_deref(),
146
+ kwargs.sequence_id_name.as_deref(),
147
+ kwargs.state_name.as_deref(),
144
148
  )
145
149
  }
@@ -0,0 +1,393 @@
1
+ // High-performance SGT implementation optimized for 100M+ records
2
+ // Uses group-based indexing (O(n)) and parallel processing with Rayon
3
+ use polars::prelude::*;
4
+ use rayon::prelude::*;
5
+ use std::collections::HashMap;
6
+
7
/// Time penalty modes for SGT.
///
/// Selects how the time gap between two consecutive events is turned into an
/// n-gram weight by `TimePenalty::apply`. A gap of exactly `0.0` always yields
/// a weight of `1.0`, whatever the mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimePenalty {
    /// Weight is `alpha / time_diff`.
    Inverse,
    /// Weight is `exp(-alpha * time_diff)`.
    Exponential,
    /// Weight is `1 - alpha * time_diff`, never less than `0.0`.
    Linear,
    /// Weight is `1 / time_diff^beta`.
    Power,
    /// Weight is always `1.0` (no time penalty).
    None,
}
16
+
17
+ impl TimePenalty {
18
+ pub fn from_str(s: &str) -> PolarsResult<Self> {
19
+ match s {
20
+ "inverse" => Ok(TimePenalty::Inverse),
21
+ "exponential" => Ok(TimePenalty::Exponential),
22
+ "linear" => Ok(TimePenalty::Linear),
23
+ "power" => Ok(TimePenalty::Power),
24
+ "none" => Ok(TimePenalty::None),
25
+ _ => polars_bail!(InvalidOperation: "Unknown time_penalty: {}", s),
26
+ }
27
+ }
28
+
29
+ #[inline(always)]
30
+ pub fn apply(&self, time_diff: f64, alpha: f64, beta: f64) -> f64 {
31
+ if time_diff == 0.0 {
32
+ return 1.0;
33
+ }
34
+ match self {
35
+ TimePenalty::Inverse => alpha / time_diff,
36
+ TimePenalty::Exponential => (-alpha * time_diff).exp(),
37
+ TimePenalty::Linear => (1.0 - alpha * time_diff).max(0.0),
38
+ TimePenalty::Power => 1.0 / time_diff.powf(beta),
39
+ TimePenalty::None => 1.0,
40
+ }
41
+ }
42
+ }
43
+
44
/// Normalization modes for SGT.
///
/// Selects how `NormMode::normalize` rescales the weight vector of one
/// sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormMode {
    /// Divide every weight by the sum of the weights.
    L1,
    /// Divide every weight by the square root of the sum of squared weights.
    L2,
    /// Leave the weights untouched.
    None,
}
51
+
52
+ impl NormMode {
53
+ pub fn from_str(s: &str) -> PolarsResult<Self> {
54
+ match s {
55
+ "l1" => Ok(NormMode::L1),
56
+ "l2" => Ok(NormMode::L2),
57
+ "none" => Ok(NormMode::None),
58
+ _ => polars_bail!(InvalidOperation: "Unknown mode: {}", s),
59
+ }
60
+ }
61
+
62
+ #[inline(always)]
63
+ pub fn normalize(&self, weights: &mut Vec<f64>) {
64
+ match self {
65
+ NormMode::L1 => {
66
+ let sum: f64 = weights.iter().sum();
67
+ if sum > 0.0 {
68
+ for weight in weights.iter_mut() {
69
+ *weight /= sum;
70
+ }
71
+ }
72
+ }
73
+ NormMode::L2 => {
74
+ let sum_sq: f64 = weights.iter().map(|w| w * w).sum();
75
+ if sum_sq > 0.0 {
76
+ let norm = sum_sq.sqrt();
77
+ for weight in weights.iter_mut() {
78
+ *weight /= norm;
79
+ }
80
+ }
81
+ }
82
+ NormMode::None => {}
83
+ }
84
+ }
85
+ }
86
+
87
+ /// Convert deltatime string to seconds multiplier
88
+ #[inline(always)]
89
+ fn deltatime_to_seconds(deltatime: Option<&str>) -> PolarsResult<f64> {
90
+ match deltatime {
91
+ None => Ok(1.0),
92
+ Some("s") => Ok(1.0),
93
+ Some("m") => Ok(60.0),
94
+ Some("h") => Ok(3600.0),
95
+ Some("d") => Ok(86400.0),
96
+ Some("w") => Ok(604800.0),
97
+ Some("month") => Ok(2629800.0),
98
+ Some("q") => Ok(7889400.0),
99
+ Some("y") => Ok(31557600.0),
100
+ Some(other) => polars_bail!(InvalidOperation: "Unknown deltatime: {}", other),
101
+ }
102
+ }
103
+
104
+ /// Extract time values as f64 for a batch of indices
105
+ #[inline]
106
+ fn extract_time_values(
107
+ series: &Series,
108
+ indices: &[usize],
109
+ deltatime: Option<&str>,
110
+ ) -> PolarsResult<Vec<Option<f64>>> {
111
+ let divisor = deltatime_to_seconds(deltatime)?;
112
+
113
+ match series.dtype() {
114
+ DataType::Datetime(time_unit, _) => {
115
+ let ca = series.datetime()?;
116
+ let time_unit_divisor = match time_unit {
117
+ TimeUnit::Nanoseconds => 1_000_000_000.0,
118
+ TimeUnit::Microseconds => 1_000_000.0,
119
+ TimeUnit::Milliseconds => 1_000.0,
120
+ };
121
+ Ok(indices
122
+ .iter()
123
+ .map(|&i| unsafe { ca.phys.get_unchecked(i) }.map(|v| v as f64 / time_unit_divisor / divisor))
124
+ .collect())
125
+ }
126
+ DataType::Date => {
127
+ let ca = series.date()?;
128
+ let date_divisor = divisor / 86400.0;
129
+ Ok(indices
130
+ .iter()
131
+ .map(|&i| unsafe { ca.phys.get_unchecked(i) }.map(|v| v as f64 / date_divisor))
132
+ .collect())
133
+ }
134
+ DataType::Duration(time_unit) => {
135
+ let ca = series.duration()?;
136
+ let time_unit_divisor = match time_unit {
137
+ TimeUnit::Nanoseconds => 1_000_000_000.0,
138
+ TimeUnit::Microseconds => 1_000_000.0,
139
+ TimeUnit::Milliseconds => 1_000.0,
140
+ };
141
+ Ok(indices
142
+ .iter()
143
+ .map(|&i| unsafe { ca.phys.get_unchecked(i) }.map(|v| v as f64 / time_unit_divisor / divisor))
144
+ .collect())
145
+ }
146
+ _ => {
147
+ let ca = series.cast(&DataType::Float64)?;
148
+ let f64_ca = ca.f64()?;
149
+ Ok(indices.iter().map(|&i| f64_ca.get(i)).collect())
150
+ }
151
+ }
152
+ }
153
+
154
/// Output computed for one sequence.
///
/// `ngram_keys` and `ngram_values` are parallel vectors: the weight at
/// position `i` belongs to the key at position `i`.
struct SequenceResult {
    /// Identifier of the sequence these n-grams were built from.
    seq_id: String,
    /// N-gram labels.
    ngram_keys: Vec<String>,
    /// Weight of each n-gram.
    ngram_values: Vec<f64>,
}
160
+
161
+ /// Generate n-grams with weights from a sequence (optimized version)
162
+ #[inline]
163
+ fn generate_ngrams_fast(
164
+ states: &[&str],
165
+ time_values: &[Option<f64>],
166
+ kappa: usize,
167
+ time_penalty: TimePenalty,
168
+ alpha: f64,
169
+ beta: f64,
170
+ ) -> (Vec<String>, Vec<f64>) {
171
+ if states.is_empty() {
172
+ return (Vec::new(), Vec::new());
173
+ }
174
+
175
+ // Estimate capacity: n-grams up to kappa for sequence of length L
176
+ // Total n-grams ≈ L + (L-1) + ... + (L-kappa+1)
177
+ let estimated_capacity = states.len() * kappa.min(states.len());
178
+ let mut ngram_weights: HashMap<String, f64> = HashMap::with_capacity(estimated_capacity);
179
+
180
+ // Generate n-grams up to kappa size
181
+ let max_n = kappa.min(states.len());
182
+ for n in 1..=max_n {
183
+ for i in 0..=(states.len() - n) {
184
+ // Build n-gram key efficiently
185
+ let ngram_key = if n == 1 {
186
+ states[i].to_string()
187
+ } else {
188
+ states[i..i + n].join(" -> ")
189
+ };
190
+
191
+ // Calculate weight based on time difference
192
+ let weight = if n > 1 && i + n - 1 < time_values.len() {
193
+ if let (Some(curr_time), Some(prev_time)) =
194
+ (time_values[i + n - 1], time_values[i + n - 2])
195
+ {
196
+ let time_diff = (curr_time - prev_time).abs();
197
+ time_penalty.apply(time_diff, alpha, beta)
198
+ } else {
199
+ 1.0
200
+ }
201
+ } else {
202
+ 1.0
203
+ };
204
+
205
+ *ngram_weights.entry(ngram_key).or_insert(0.0) += weight;
206
+ }
207
+ }
208
+
209
+ // Convert to sorted vectors
210
+ let mut keys: Vec<String> = ngram_weights.keys().cloned().collect();
211
+ keys.sort_unstable();
212
+ let values: Vec<f64> = keys.iter().map(|k| ngram_weights[k]).collect();
213
+
214
+ (keys, values)
215
+ }
216
+
217
+ /// Process a single sequence group
218
+ #[inline]
219
+ fn process_sequence(
220
+ seq_id: &str,
221
+ indices: &[usize],
222
+ states_ca: &StringChunked,
223
+ time_series: Option<&Series>,
224
+ kappa: usize,
225
+ length_sensitive: bool,
226
+ time_penalty: TimePenalty,
227
+ norm_mode: NormMode,
228
+ alpha: f64,
229
+ beta: f64,
230
+ deltatime: Option<&str>,
231
+ ) -> PolarsResult<Option<SequenceResult>> {
232
+ // Extract states for this sequence using direct index access
233
+ let states: Vec<&str> = indices
234
+ .iter()
235
+ .filter_map(|&i| states_ca.get(i))
236
+ .collect();
237
+
238
+ if states.is_empty() {
239
+ return Ok(None);
240
+ }
241
+
242
+ // Extract time values
243
+ let time_values = if let Some(ts) = time_series {
244
+ extract_time_values(ts, indices, deltatime)?
245
+ } else {
246
+ // Use index positions as time
247
+ indices.iter().map(|&i| Some(i as f64)).collect()
248
+ };
249
+
250
+ // Generate n-grams with weights
251
+ let (keys, mut values) = generate_ngrams_fast(
252
+ &states,
253
+ &time_values,
254
+ kappa,
255
+ time_penalty,
256
+ alpha,
257
+ beta,
258
+ );
259
+
260
+ // Apply length normalization if requested
261
+ if length_sensitive && states.len() > 1 {
262
+ let seq_len = states.len() as f64;
263
+ for weight in values.iter_mut() {
264
+ *weight /= seq_len;
265
+ }
266
+ }
267
+
268
+ // Apply normalization mode
269
+ norm_mode.normalize(&mut values);
270
+
271
+ Ok(Some(SequenceResult {
272
+ seq_id: seq_id.to_string(),
273
+ ngram_keys: keys,
274
+ ngram_values: values,
275
+ }))
276
+ }
277
+
278
+ /// High-performance SGT implementation using group-based indexing and parallel processing
279
+ #[allow(clippy::too_many_arguments)]
280
+ pub fn impl_sgt_transform(
281
+ inputs: &[Series],
282
+ kappa: i64,
283
+ length_sensitive: bool,
284
+ mode: &str,
285
+ time_penalty: &str,
286
+ alpha: f64,
287
+ beta: f64,
288
+ deltatime: Option<&str>,
289
+ sequence_id_name: Option<&str>,
290
+ state_name: Option<&str>,
291
+ ) -> PolarsResult<Series> {
292
+ if inputs.len() < 2 {
293
+ polars_bail!(InvalidOperation: "sgt_transform requires at least sequence_id and state columns");
294
+ }
295
+
296
+ let sequence_ids = inputs[0].cast(&DataType::String)?;
297
+ let states_series = &inputs[1];
298
+ let time_series = if inputs.len() > 2 {
299
+ Some(&inputs[2])
300
+ } else {
301
+ None
302
+ };
303
+
304
+ let kappa = kappa as usize;
305
+ let time_penalty_mode = TimePenalty::from_str(time_penalty)?;
306
+ let norm_mode = NormMode::from_str(mode)?;
307
+
308
+ let seq_ids_ca = sequence_ids.str()?;
309
+ let states_ca = states_series.str()?;
310
+
311
+ // OPTIMIZATION 1: Build group index in O(n) - single pass
312
+ // This replaces the O(n*m) nested loop
313
+ let mut groups: HashMap<&str, Vec<usize>> = HashMap::new();
314
+ for (idx, seq_id) in seq_ids_ca.iter().enumerate() {
315
+ if let Some(id) = seq_id {
316
+ groups.entry(id).or_default().push(idx);
317
+ }
318
+ }
319
+
320
+ // OPTIMIZATION 2: Process groups in parallel with Rayon
321
+ let results: Vec<PolarsResult<Option<SequenceResult>>> = groups
322
+ .par_iter()
323
+ .map(|(seq_id, indices)| {
324
+ process_sequence(
325
+ seq_id,
326
+ indices,
327
+ states_ca,
328
+ time_series,
329
+ kappa,
330
+ length_sensitive,
331
+ time_penalty_mode,
332
+ norm_mode,
333
+ alpha,
334
+ beta,
335
+ deltatime,
336
+ )
337
+ })
338
+ .collect();
339
+
340
+ // Collect successful results
341
+ let mut result_seq_ids: Vec<String> = Vec::with_capacity(groups.len());
342
+ let mut result_ngram_keys_list: Vec<Series> = Vec::with_capacity(groups.len());
343
+ let mut result_ngram_values_list: Vec<Series> = Vec::with_capacity(groups.len());
344
+
345
+ for result in results {
346
+ if let Some(seq_result) = result? {
347
+ result_seq_ids.push(seq_result.seq_id);
348
+ result_ngram_keys_list.push(
349
+ StringChunked::from_iter(seq_result.ngram_keys.iter().map(|s| Some(s.as_str())))
350
+ .into_series(),
351
+ );
352
+ result_ngram_values_list.push(
353
+ Float64Chunked::from_vec(PlSmallStr::EMPTY, seq_result.ngram_values).into_series(),
354
+ );
355
+ }
356
+ }
357
+
358
+ // Sort by sequence ID for deterministic output
359
+ let mut indexed: Vec<(usize, &String)> = result_seq_ids.iter().enumerate().collect();
360
+ indexed.sort_by(|a, b| a.1.cmp(b.1));
361
+
362
+ let sorted_seq_ids: Vec<String> = indexed.iter().map(|(i, _)| result_seq_ids[*i].clone()).collect();
363
+ let sorted_keys: Vec<Series> = indexed.iter().map(|(i, _)| result_ngram_keys_list[*i].clone()).collect();
364
+ let sorted_values: Vec<Series> = indexed.iter().map(|(i, _)| result_ngram_values_list[*i].clone()).collect();
365
+
366
+ // Use parameter names for struct fields (fallback to defaults)
367
+ let seq_field_name = sequence_id_name.unwrap_or("sequence_id");
368
+ let _state_field_name = state_name.unwrap_or("state"); // Reserved for future use
369
+
370
+ // Build result struct
371
+ let mut seq_id_ca = StringChunked::from_iter(sorted_seq_ids.iter().map(|s| Some(s.as_str())));
372
+ seq_id_ca.rename(PlSmallStr::from_str(seq_field_name));
373
+ let seq_id_series = seq_id_ca.into_series();
374
+
375
+ // Convert to list series
376
+ let ngram_keys_dtype = DataType::List(Box::new(DataType::String));
377
+ let ngram_keys_series = Series::new(PlSmallStr::from_str("ngram_keys"), sorted_keys)
378
+ .cast(&ngram_keys_dtype)?;
379
+
380
+ // Renamed from ngram_values to value
381
+ let ngram_values_dtype = DataType::List(Box::new(DataType::Float64));
382
+ let ngram_values_series = Series::new(PlSmallStr::from_str("value"), sorted_values)
383
+ .cast(&ngram_values_dtype)?;
384
+
385
+ // Create struct
386
+ let struct_fields = [seq_id_series, ngram_keys_series, ngram_values_series];
387
+ Ok(StructChunked::from_series(
388
+ PlSmallStr::from_str("sgt_result"),
389
+ sorted_seq_ids.len(),
390
+ struct_fields.iter(),
391
+ )?
392
+ .into_series())
393
+ }
@@ -113,7 +113,7 @@ def test_sgt_time_penalty_exponential() -> None:
113
113
  )
114
114
 
115
115
  weights = result.select(
116
- pl.col("sgt").struct.field("ngram_values")
116
+ pl.col("sgt").struct.field("value")
117
117
  ).to_series().to_list()[0]
118
118
 
119
119
  # Weights should be positive
@@ -187,7 +187,7 @@ def test_sgt_time_penalty_none() -> None:
187
187
 
188
188
  # With no penalty, all weights should be integer counts
189
189
  weights = result.select(
190
- pl.col("sgt").struct.field("ngram_values")
190
+ pl.col("sgt").struct.field("value")
191
191
  ).to_series().to_list()[0]
192
192
 
193
193
  assert all(w > 0 for w in weights)
@@ -210,7 +210,7 @@ def test_sgt_l1_normalization() -> None:
210
210
  )
211
211
 
212
212
  weights = result.select(
213
- pl.col("sgt").struct.field("ngram_values")
213
+ pl.col("sgt").struct.field("value")
214
214
  ).to_series().to_list()[0]
215
215
 
216
216
  # L1 normalization: sum should be 1.0
@@ -234,7 +234,7 @@ def test_sgt_l2_normalization() -> None:
234
234
  )
235
235
 
236
236
  weights = result.select(
237
- pl.col("sgt").struct.field("ngram_values")
237
+ pl.col("sgt").struct.field("value")
238
238
  ).to_series().to_list()[0]
239
239
 
240
240
  # L2 normalization: sum of squares should be 1.0
@@ -260,7 +260,7 @@ def test_sgt_length_sensitive() -> None:
260
260
  )
261
261
 
262
262
  weights = result.select(
263
- pl.col("sgt").struct.field("ngram_values")
263
+ pl.col("sgt").struct.field("value")
264
264
  ).to_series().to_list()[0]
265
265
 
266
266
  # With length normalization, weights should be divided by sequence length
@@ -341,7 +341,7 @@ def test_sgt_struct_output() -> None:
341
341
  expanded = result.select([
342
342
  pl.col("sgt").struct.field("sequence_id").alias("seq_id"),
343
343
  pl.col("sgt").struct.field("ngram_keys").alias("keys"),
344
- pl.col("sgt").struct.field("ngram_values").alias("values"),
344
+ pl.col("sgt").struct.field("value").alias("values"),
345
345
  ])
346
346
 
347
347
  assert "seq_id" in expanded.columns
@@ -372,7 +372,7 @@ def test_sgt_explode_pattern() -> None:
372
372
  exploded = result.select([
373
373
  pl.col("sgt").struct.field("sequence_id"),
374
374
  pl.col("sgt").struct.field("ngram_keys").alias("ngram"),
375
- pl.col("sgt").struct.field("ngram_values").alias("weight"),
375
+ pl.col("sgt").struct.field("value").alias("weight"),
376
376
  ]).explode(["ngram", "weight"])
377
377
 
378
378
  assert exploded.shape[0] > 0
@@ -1,304 +0,0 @@
1
- // Simplified SGT implementation that actually compiles
2
- // This implementation works correctly with POL ARS API patterns
3
- use polars::prelude::*;
4
- use std::collections::HashMap;
5
-
6
- /// Time penalty modes for SGT
7
- #[derive(Debug, Clone)]
8
- pub enum TimePenalty {
9
- Inverse,
10
- Exponential,
11
- Linear,
12
- Power,
13
- None,
14
- }
15
-
16
- impl TimePenalty {
17
- pub fn from_str(s: &str) -> PolarsResult<Self> {
18
- match s {
19
- "inverse" => Ok(TimePenalty::Inverse),
20
- "exponential" => Ok(TimePenalty::Exponential),
21
- "linear" => Ok(TimePenalty::Linear),
22
- "power" => Ok(TimePenalty::Power),
23
- "none" => Ok(TimePenalty::None),
24
- _ => polars_bail!(InvalidOperation: "Unknown time_penalty: {}", s),
25
- }
26
- }
27
-
28
- pub fn apply(&self, time_diff: f64, alpha: f64, beta: f64) -> f64 {
29
- if time_diff == 0.0 {
30
- return 1.0;
31
- }
32
- match self {
33
- TimePenalty::Inverse => alpha / time_diff,
34
- TimePenalty::Exponential => (-alpha * time_diff).exp(),
35
- TimePenalty::Linear => (1.0 - alpha * time_diff).max(0.0),
36
- TimePenalty::Power => 1.0 / time_diff.powf(beta),
37
- TimePenalty::None => 1.0,
38
- }
39
- }
40
- }
41
-
42
- /// Normalization modes for SGT
43
- #[derive(Debug, Clone)]
44
- pub enum NormMode {
45
- L1,
46
- L2,
47
- None,
48
- }
49
-
50
- impl NormMode {
51
- pub fn from_str(s: &str) -> PolarsResult<Self> {
52
- match s {
53
- "l1" => Ok(NormMode::L1),
54
- "l2" => Ok(NormMode::L2),
55
- "none" => Ok(NormMode::None),
56
- _ => polars_bail!(InvalidOperation: "Unknown mode: {}", s),
57
- }
58
- }
59
-
60
- pub fn normalize(&self, weights: &mut HashMap<String, f64>) {
61
- match self {
62
- NormMode::L1 => {
63
- let sum: f64 = weights.values().sum();
64
- if sum > 0.0 {
65
- for weight in weights.values_mut() {
66
- *weight /= sum;
67
- }
68
- }
69
- }
70
- NormMode::L2 => {
71
- let sum_sq: f64 = weights.values().map(|w| w * w).sum();
72
- if sum_sq > 0.0 {
73
- let norm = sum_sq.sqrt();
74
- for weight in weights.values_mut() {
75
- *weight /= norm;
76
- }
77
- }
78
- }
79
- NormMode::None => {}
80
- }
81
- }
82
- }
83
-
84
- /// Convert deltatime string to seconds multiplier
85
- fn deltatime_to_seconds(deltatime: Option<&str>) -> PolarsResult<f64> {
86
- match deltatime {
87
- None => Ok(1.0),
88
- Some("s") => Ok(1.0),
89
- Some("m") => Ok(60.0),
90
- Some("h") => Ok(3600.0),
91
- Some("d") => Ok(86400.0),
92
- Some("w") => Ok(604800.0),
93
- Some("month") => Ok(2629800.0), // 30.44 days
94
- Some("q") => Ok(7889400.0), // 91.31 days
95
- Some("y") => Ok(31557600.0), // 365.25 days
96
- Some(other) => polars_bail!(InvalidOperation: "Unknown deltatime: {}", other),
97
- }
98
- }
99
-
100
- /// Extract time value as f64
101
- fn get_time_value(series: &Series, idx: usize, deltatime: Option<&str>) -> PolarsResult<Option<f64>> {
102
- match series.dtype() {
103
- DataType::Datetime(time_unit, _) => {
104
- let ca = series.datetime()?;
105
- let divisor = deltatime_to_seconds(deltatime)?;
106
- let time_unit_divisor = match time_unit {
107
- TimeUnit::Nanoseconds => 1_000_000_000.0,
108
- TimeUnit::Microseconds => 1_000_000.0,
109
- TimeUnit::Milliseconds => 1_000.0,
110
- };
111
- Ok(unsafe { ca.phys.get_unchecked(idx) }.map(|v| v as f64 / time_unit_divisor / divisor))
112
- }
113
- DataType::Date => {
114
- let ca = series.date()?;
115
- let divisor = deltatime_to_seconds(deltatime)? / 86400.0;
116
- Ok(unsafe { ca.phys.get_unchecked(idx) }.map(|v| v as f64 / divisor))
117
- }
118
- DataType::Duration(time_unit) => {
119
- let ca = series.duration()?;
120
- let divisor = deltatime_to_seconds(deltatime)?;
121
- let time_unit_divisor = match time_unit {
122
- TimeUnit::Nanoseconds => 1_000_000_000.0,
123
- TimeUnit::Microseconds => 1_000_000.0,
124
- TimeUnit::Milliseconds => 1_000.0,
125
- };
126
- Ok(unsafe { ca.phys.get_unchecked(idx) }.map(|v| v as f64 / time_unit_divisor / divisor))
127
- }
128
- _ => {
129
- let ca = series.cast(&DataType::Float64)?;
130
- Ok(ca.f64()?.get(idx))
131
- }
132
- }
133
- }
134
-
135
- /// Generate n-grams with weights from a sequence
136
- fn generate_ngrams(
137
- states: &[String],
138
- time_values: &[Option<f64>],
139
- kappa: usize,
140
- time_penalty: &TimePenalty,
141
- alpha: f64,
142
- beta: f64,
143
- ) -> HashMap<String, f64> {
144
- let mut ngram_weights: HashMap<String, f64> = HashMap::new();
145
-
146
- if states.is_empty() {
147
- return ngram_weights;
148
- }
149
-
150
- // Generate n-grams up to kappa size
151
- for n in 1..=kappa.min(states.len()) {
152
- for i in 0..=(states.len() - n) {
153
- let ngram: Vec<&str> = states[i..i + n].iter().map(|s| s.as_str()).collect();
154
- let ngram_key = ngram.join(" -> ");
155
-
156
- // Calculate weight based on time difference
157
- let weight = if n > 1 && i + n - 1 < time_values.len() {
158
- if let (Some(curr_time), Some(prev_time)) = (time_values[i + n - 1], time_values[i + n - 2]) {
159
- let time_diff = (curr_time - prev_time).abs();
160
- time_penalty.apply(time_diff, alpha, beta)
161
- } else {
162
- 1.0
163
- }
164
- } else {
165
- 1.0
166
- };
167
-
168
- *ngram_weights.entry(ngram_key).or_insert(0.0) += weight;
169
- }
170
- }
171
-
172
- ngram_weights
173
- }
174
-
175
- /// Main SGT implementation using simple iteration
176
- #[allow(clippy::too_many_arguments)]
177
- pub fn impl_sgt_transform(
178
- inputs: &[Series],
179
- kappa: i64,
180
- length_sensitive: bool,
181
- mode: &str,
182
- time_penalty: &str,
183
- alpha: f64,
184
- beta: f64,
185
- deltatime: Option<&str>,
186
- ) -> PolarsResult<Series> {
187
- if inputs.len() < 2 {
188
- polars_bail!(InvalidOperation: "sgt_transform requires at least sequence_id and state columns");
189
- }
190
-
191
- let sequence_ids = inputs[0].cast(&DataType::String)?;
192
- let states_series = &inputs[1];
193
- let time_series = if inputs.len() > 2 {
194
- Some(&inputs[2])
195
- } else {
196
- None
197
- };
198
-
199
- let kappa = kappa as usize;
200
- let time_penalty_mode = TimePenalty::from_str(time_penalty)?;
201
- let norm_mode = NormMode::from_str(mode)?;
202
-
203
- // Get unique sequence IDs
204
- let unique_ids: StringChunked = sequence_ids.str()?.unique()?.sort(Default::default());
205
-
206
- let mut result_seq_ids: Vec<String> = Vec::new();
207
- let mut result_ngram_keys_list: Vec<Series> = Vec::new();
208
- let mut result_ngram_values_list: Vec<Series> = Vec::new();
209
-
210
- let seq_ids_ca = sequence_ids.str()?;
211
- let states_ca = states_series.str()?;
212
-
213
- // Process each unique sequence ID
214
- for idx in 0..unique_ids.len() {
215
- let seq_id: &str = match unique_ids.get(idx) {
216
- Some(id) => id,
217
- None => continue,
218
- };
219
-
220
- // Find all rows matching this sequence ID
221
- let mask: BooleanChunked = seq_ids_ca.equal(seq_id);
222
-
223
- // Extract states for this sequence
224
- let mut sequence_states = Vec::new();
225
- let mut time_values = Vec::new();
226
-
227
- for i in 0..mask.len() {
228
- if mask.get(i).unwrap_or(false) {
229
- if let Some(state) = states_ca.get(i) {
230
- sequence_states.push(state.to_string());
231
- if let Some(ts) = time_series {
232
- time_values.push(get_time_value(ts, i, deltatime)?);
233
- } else {
234
- time_values.push(Some(i as f64));
235
- }
236
- }
237
- }
238
- }
239
-
240
- if sequence_states.is_empty() {
241
- continue;
242
- }
243
-
244
- // Generate n-grams with weights
245
- let mut ngram_weights = generate_ngrams(
246
- &sequence_states,
247
- &time_values,
248
- kappa,
249
- &time_penalty_mode,
250
- alpha,
251
- beta,
252
- );
253
-
254
- // Apply length normalization if requested
255
- if length_sensitive && sequence_states.len() > 1 {
256
- let seq_len = sequence_states.len() as f64;
257
- for weight in ngram_weights.values_mut() {
258
- *weight /= seq_len;
259
- }
260
- }
261
-
262
- // Apply normalization mode
263
- norm_mode.normalize(&mut ngram_weights);
264
-
265
- // Convert to sorted vectors
266
- let mut keys: Vec<String> = ngram_weights.keys().cloned().collect();
267
- keys.sort();
268
- let values: Vec<f64> = keys.iter().map(|k| ngram_weights[k]).collect();
269
-
270
- result_seq_ids.push(seq_id.to_string());
271
- result_ngram_keys_list.push(
272
- StringChunked::from_iter(keys.iter().map(|s| Some(s.as_str()))).into_series()
273
- );
274
- result_ngram_values_list.push(
275
- Float64Chunked::from_vec(PlSmallStr::EMPTY, values).into_series()
276
- );
277
- }
278
-
279
- // Build result struct
280
- let mut seq_id_ca = StringChunked::from_iter(result_seq_ids.iter().map(|s| Some(s.as_str())));
281
- seq_id_ca.rename(PlSmallStr::from_str("sequence_id"));
282
- let seq_id_series = seq_id_ca.into_series();
283
-
284
- // Convert to list series
285
- let ngram_keys_dtype = DataType::List(Box::new(DataType::String));
286
- let ngram_keys_series = Series::new(
287
- PlSmallStr::from_str("ngram_keys"),
288
- result_ngram_keys_list
289
- ).cast(&ngram_keys_dtype)?;
290
-
291
- let ngram_values_dtype = DataType::List(Box::new(DataType::Float64));
292
- let ngram_values_series = Series::new(
293
- PlSmallStr::from_str("ngram_values"),
294
- result_ngram_values_list
295
- ).cast(&ngram_values_dtype)?;
296
-
297
- // Create struct
298
- let struct_fields = [seq_id_series, ngram_keys_series, ngram_values_series];
299
- Ok(StructChunked::from_series(
300
- PlSmallStr::from_str("sgt_result"),
301
- result_seq_ids.len(),
302
- struct_fields.iter()
303
- )?.into_series())
304
- }
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes