febolt 0.1.56__tar.gz → 0.1.58__tar.gz

{febolt-0.1.56 → febolt-0.1.58}/Cargo.lock
@@ -362,7 +362,7 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
 
 [[package]]
 name = "febolt"
-version = "0.1.56"
+version = "0.1.58"
 dependencies = [
  "blas",
  "intel-mkl-src",
@@ -1046,9 +1046,9 @@ dependencies = [
 
 [[package]]
 name = "openssl"
-version = "0.10.70"
+version = "0.10.71"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6"
+checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd"
 dependencies = [
  "bitflags",
  "cfg-if",
@@ -1087,9 +1087,9 @@ dependencies = [
 
 [[package]]
 name = "openssl-sys"
-version = "0.9.105"
+version = "0.9.106"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8b22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc"
+checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd"
 dependencies = [
  "cc",
  "libc",
{febolt-0.1.56 → febolt-0.1.58}/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "febolt"
-version = "0.1.56"
+version = "0.1.58"
 edition = "2021"
 description = "Statistics library for Python powered by Rust"
 license = "MIT"
{febolt-0.1.56 → febolt-0.1.58}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: febolt
-Version: 0.1.56
+Version: 0.1.58
 Classifier: Programming Language :: Rust
 Classifier: Programming Language :: Python :: Implementation :: CPython
 Classifier: Programming Language :: Python :: Implementation :: PyPy
{febolt-0.1.56 → febolt-0.1.58}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "maturin"
 
 [project]
 name = "febolt"
-version = "0.1.56"
+version = "0.1.58"
 requires-python = ">=3.8"
 description = "A Rust-based Statistics and ML package, callable from Python."
 keywords = ["rust", "python", "Machine Learning", "Statistics", "pyo3"]
febolt-0.1.58/src/lib.rs ADDED
@@ -0,0 +1,312 @@
+// src/lib.rs
+
+use pyo3::prelude::*;
+use pyo3::types::{PyAny, PyDict};
+use pyo3::wrap_pyfunction;
+
+use numpy::{PyArray1, PyArray2};
+use ndarray::{Array2, ArrayView1, ArrayView2, s};
+use statrs::distribution::{Normal, ContinuousCDF};
+
+/// Add significance stars.
+fn add_significance_stars(p: f64) -> &'static str {
+    if p < 0.01 {
+        "***"
+    } else if p < 0.05 {
+        "**"
+    } else if p < 0.1 {
+        "*"
+    } else {
+        ""
+    }
+}
+
+/// A tiny helper to check if we have a Logit or Probit from statsmodels
+fn detect_model_type(model: &PyAny) -> Result<bool, PyErr> {
+    // Return is_logit: true if logit, false if probit
+    let model_obj = model.getattr("model").unwrap_or(model);
+    let cls: String = model_obj
+        .getattr("__class__")?
+        .getattr("__name__")?
+        .extract()?;
+    let lc = cls.to_lowercase();
+    if lc == "logit" {
+        Ok(true) // is_logit = true
+    } else if lc == "probit" {
+        Ok(false) // is_logit = false
+    } else {
+        Err(pyo3::exceptions::PyValueError::new_err(
+            format!("ame: only Logit or Probit supported, got {cls}"),
+        ))
+    }
+}
+
+/// Downcast a Python object to NumPy PyArray2<f64> => ndarray::ArrayView2<f64>.
+///
+/// Marked unsafe because .as_array() in pyo3-numpy is unsafe, trusting Python memory is valid.
+fn as_array2_f64<'py>(obj: &'py PyAny) -> PyResult<ArrayView2<'py, f64>> {
+    let pyarray = obj.downcast::<PyArray2<f64>>()?;
+    let view = unsafe { pyarray.as_array() };
+    Ok(view)
+}
+
+/// Similarly, for 1D arrays.
+fn as_array1_f64<'py>(obj: &'py PyAny) -> PyResult<ArrayView1<'py, f64>> {
+    let pyarray = obj.downcast::<PyArray1<f64>>()?;
+    let view = unsafe { pyarray.as_array() };
+    Ok(view)
+}
+
+/// Evaluate logistic or normal cdf
+fn cdf_logit_probit(is_logit: bool, z: f64) -> f64 {
+    if is_logit {
+        // logistic cdf
+        1.0 / (1.0 + (-z).exp())
+    } else {
+        // normal cdf
+        let dist = Normal::new(0.0, 1.0).unwrap();
+        dist.cdf(z)
+    }
+}
+/// Evaluate logistic or normal pdf
+fn pdf_logit_probit(is_logit: bool, z: f64) -> f64 {
+    if is_logit {
+        // logistic pdf
+        let e = z.exp();
+        e / (1.0 + e).powi(2)
+    } else {
+        // normal pdf
+        (-0.5 * z * z).exp() / (2.0 * std::f64::consts::PI).sqrt()
+    }
+}
+
+/// The main function that calculates Logit/Probit AME in the same style as the original Probit code
+#[pyfunction]
+fn ame<'py>(
+    py: Python<'py>,
+    model: &'py PyAny,         // Could be Logit or Probit
+    chunk_size: Option<usize>, // optional chunk
+) -> PyResult<&'py PyAny> {
+    // 1) detect Logit vs Probit
+    let is_logit = detect_model_type(model)?;
+
+    // 2) read params (handle pandas Series)
+    let params_obj = model.getattr("params")?;
+    let params_pyarray = if let Ok(values) = params_obj.getattr("values") {
+        values.downcast::<PyArray1<f64>>()?
+    } else {
+        params_obj.downcast::<PyArray1<f64>>()?
+    };
+    let beta = unsafe { params_pyarray.as_array() };
+
+    // 3) read cov (handle pandas DataFrame)
+    let cov_obj = model.call_method0("cov_params")?;
+    let cov_pyarray = if let Ok(values) = cov_obj.getattr("values") {
+        values.downcast::<PyArray2<f64>>()?
+    } else {
+        cov_obj.downcast::<PyArray2<f64>>()?
+    };
+    let cov_beta = unsafe { cov_pyarray.as_array() };
+
+    // 4) Get model object and handle exog (X) and exog_names
+    let model_obj = model.getattr("model").unwrap_or(model);
+
+    // Handle pandas DataFrame input
+    let exog_py = model_obj.getattr("exog")?;
+    let (x_pyarray, exog_names) = if let Ok(values) = exog_py.getattr("values") {
+        // Pandas DataFrame path
+        (
+            values.downcast::<PyArray2<f64>>()?,
+            exog_py.getattr("columns")?.extract::<Vec<String>>()?
+        )
+    } else {
+        // Numpy array path
+        (
+            exog_py.downcast::<PyArray2<f64>>()?,
+            model_obj.getattr("exog_names")?.extract::<Vec<String>>()?
+        )
+    };
+
+    let X = unsafe { x_pyarray.as_array() };
+    let (n, k) = (X.nrows(), X.ncols());
+    let chunk = chunk_size.unwrap_or(n);
+
+    // 5) Identify intercept columns
+    let intercept_indices: Vec<usize> = exog_names
+        .iter()
+        .enumerate()
+        .filter_map(|(i, nm)| {
+            let ln = nm.to_lowercase();
+            if ln == "const" || ln == "intercept" {
+                Some(i)
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    // 6) Identify discrete columns => strictly 0/1
+    let is_discrete: Vec<usize> = exog_names
+        .iter()
+        .enumerate()
+        .filter_map(|(j, _)| {
+            if intercept_indices.contains(&j) {
+                None
+            } else {
+                let col_j = X.column(j);
+                if col_j.iter().all(|&v| v == 0.0 || v == 1.0) {
+                    Some(j)
+                } else {
+                    None
+                }
+            }
+        })
+        .collect();
+
+    // 7) Prepare accumulators
+    let mut sum_ame = vec![0.0; k]; // sum partial effects
+    let mut partial_jl_sums = vec![0.0; k * k];
+    let normal = Normal::new(0.0, 1.0).unwrap();
+
+    // 8) single pass with chunk
+    let mut idx_start = 0;
+    while idx_start < n {
+        let idx_end = (idx_start + chunk).min(n);
+        let x_chunk = X.slice(s![idx_start..idx_end, ..]);
+        let z_chunk = x_chunk.dot(&beta); // shape(n_chunk)
+
+        // pdf => we might do partial for continuous
+        let pdf_chunk = z_chunk.mapv(|z| pdf_logit_probit(is_logit, z));
+        // pdf derivative f'(z), needed for the delta-method Jacobian below:
+        // logistic: f'(z) = f(z) * (1 - 2*F(z)); normal: f'(z) = -z * f(z)
+        let dpdf_chunk = z_chunk.mapv(|z| {
+            let f = pdf_logit_probit(is_logit, z);
+            if is_logit {
+                f * (1.0 - 2.0 * cdf_logit_probit(true, z))
+            } else {
+                -z * f
+            }
+        });
+
+        // handle discrete
+        for &j in &is_discrete {
+            let xj_col = x_chunk.column(j);
+            let b_j = beta[j];
+            // z_j1 => z + (1-xj)*b_j
+            // z_j0 => z - xj*b_j
+            let delta_j1 = (1.0 - &xj_col).mapv(|x| x * b_j);
+            let delta_j0 = xj_col.mapv(|x| -x * b_j);
+            let z_j1 = &z_chunk + &delta_j1;
+            let z_j0 = &z_chunk + &delta_j0;
+
+            let cdf_j1 = z_j1.mapv(|z| cdf_logit_probit(is_logit, z));
+            let cdf_j0 = z_j0.mapv(|z| cdf_logit_probit(is_logit, z));
+            // sum
+            let effect_sum = cdf_j1.sum() - cdf_j0.sum();
+            sum_ame[j] += effect_sum;
+
+            // partial_jl_sums => row j, col l
+            let pdf_j1 = z_j1.mapv(|z| pdf_logit_probit(is_logit, z));
+            let pdf_j0 = z_j0.mapv(|z| pdf_logit_probit(is_logit, z));
+            for l in 0..k {
+                let grad = if l == j {
+                    // special case
+                    pdf_j1.sum()
+                } else {
+                    let x_l = x_chunk.column(l);
+                    let diff_pdf = &pdf_j1 - &pdf_j0;
+                    diff_pdf.dot(&x_l)
+                };
+                partial_jl_sums[j * k + l] += grad;
+            }
+        }
+
+        // handle continuous
+        for j in 0..k {
+            if intercept_indices.contains(&j) || is_discrete.contains(&j) {
+                continue;
+            }
+            let b_j = beta[j];
+            // sum_ame
+            sum_ame[j] += b_j * pdf_chunk.sum();
+            // partial_jl_sums => row j, col l:
+            // d/dbeta_l of b_j * f(z) is b_j * f'(z) * x_l, plus f(z) itself when l == j
+            for l in 0..k {
+                let x_l = x_chunk.column(l);
+                let mut grad = b_j * dpdf_chunk.dot(&x_l);
+                if l == j {
+                    grad += pdf_chunk.sum();
+                }
+                partial_jl_sums[j * k + l] += grad;
+            }
+        }
+
+        idx_start = idx_end;
+    }
+
+    // 9) average sums
+    let ame: Vec<f64> = sum_ame.iter().map(|v| v / (n as f64)).collect();
+
+    // gradient matrix => shape(k,k)
+    let mut grad_ame = Array2::<f64>::zeros((k,k));
+    for j in 0..k {
+        for l in 0..k {
+            grad_ame[[j,l]] = partial_jl_sums[j * k + l] / (n as f64);
+        }
+    }
+
+    // cov => grad_ame * cov_beta * grad_ame^T
+    let cov_ame = grad_ame.dot(&cov_beta).dot(&grad_ame.t());
+    let var_ame: Vec<f64> = cov_ame.diag().iter().map(|&v| v.max(0.0)).collect();
+    let se_ame: Vec<f64> = var_ame.iter().map(|&v| v.sqrt()).collect();
+
+    // 10) Build final results
+    let (mut dy_dx, mut se_err, mut z_vals, mut p_vals, mut sig) =
+        (Vec::new(), Vec::new(), Vec::new(), Vec::new(), Vec::new());
+    let mut names_out = Vec::new();
+
+    for j in 0..k {
+        if intercept_indices.contains(&j) {
+            continue;
+        }
+        let dy = ame[j];
+        let s = se_ame[j];
+        dy_dx.push(dy);
+        se_err.push(s);
+        if s > 1e-15 {
+            let z = dy / s;
+            let p = 2.0*(1.0 - normal.cdf(z.abs()));
+            z_vals.push(z);
+            p_vals.push(p);
+            sig.push(add_significance_stars(p));
+        } else {
+            z_vals.push(0.0);
+            p_vals.push(1.0);
+            sig.push("");
+        }
+        names_out.push(exog_names[j].clone());
+    }
+
+    // 11) Create DataFrame
+    let pd = py.import("pandas")?;
+    let data = PyDict::new(py);
+    data.set_item("dy/dx", &dy_dx)?;
+    data.set_item("Std. Err", &se_err)?;
+    data.set_item("z", &z_vals)?;
+    data.set_item("Pr(>|z|)", &p_vals)?;
+    data.set_item("Significance", &sig)?;
+
+    let kwargs = PyDict::new(py);
+    kwargs.set_item("data", data)?;
+    kwargs.set_item("index", &names_out)?;
+
+    let df = pd.call_method("DataFrame", (), Some(kwargs))?;
+    Ok(df)
+}
+
+
+#[pymodule]
+fn febolt(_py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_function(wrap_pyfunction!(ame, m)?)?;
+    Ok(())
+}
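
For reference, the standard errors in the new ame() come from the delta method: the Jacobian J with entries J[j][l] = d(AME_j)/d(beta_l) is accumulated in partial_jl_sums, and the covariance of the effects is Cov(AME) = J * Cov(beta) * J^T (the grad_ame.dot(&cov_beta).dot(&grad_ame.t()) line above); z-scores and p-values then use the square roots of its diagonal.

A minimal usage sketch from the Python side. The data generation and statsmodels setup here are illustrative assumptions, not taken from this package's docs; the ame(model, chunk_size=None) signature and the returned DataFrame columns are read off the source above:

    import numpy as np
    import statsmodels.api as sm
    import febolt

    # Illustrative data: one continuous and one strictly-0/1 regressor plus a constant.
    rng = np.random.default_rng(0)
    x_cont = rng.normal(size=500)
    x_bin = rng.integers(0, 2, size=500).astype(float)
    X = sm.add_constant(np.column_stack([x_cont, x_bin]))
    p_true = 1.0 / (1.0 + np.exp(-(0.5 * x_cont + 0.8 * x_bin)))
    y = (rng.uniform(size=500) < p_true).astype(float)

    # Fit a statsmodels Logit, then ask febolt for average marginal effects.
    fit = sm.Logit(y, X).fit(disp=0)
    print(febolt.ame(fit))           # columns: dy/dx, Std. Err, z, Pr(>|z|), Significance
    print(febolt.ame(fit, 100_000))  # optional chunk_size bounds per-pass memory

The intercept row is dropped from the output, strictly-0/1 columns get the discrete effect F(z with x_j=1) - F(z with x_j=0), and all other columns get beta_j * mean(f(z)), matching the branches in the loop above.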
febolt-0.1.56/src/lib.rs DELETED
@@ -1,281 +0,0 @@
-use pyo3::prelude::*;
-use pyo3::types::{PyAny, PyDict};
-use pyo3::wrap_pyfunction;
-
-use numpy::{PyArray1, PyArray2};
-use ndarray::{Array1, Array2, Axis, s};
-use statrs::distribution::{Normal, ContinuousCDF};
-use std::f64;
-
-/// Distinguish whether statsmodels model is Logit or Probit
-fn is_logit_model(model: &PyAny) -> PyResult<bool> {
-    // Return true if Logit, false if Probit
-    let model_obj = model.getattr("model").unwrap_or(model);
-    let cls_name: String = model_obj.getattr("__class__")?.getattr("__name__")?.extract()?;
-    let lower = cls_name.to_lowercase();
-    if lower == "logit" {
-        Ok(true)
-    } else if lower == "probit" {
-        Ok(false)
-    } else {
-        Err(pyo3::exceptions::PyValueError::new_err(
-            format!("ame() only supports Logit or Probit. Got: {cls_name}"),
-        ))
-    }
-}
-
-/// Evaluate logistic or normal cdf
-fn cdf_logit_probit(is_logit: bool, z: f64) -> f64 {
-    if is_logit {
-        // logistic cdf => 1/(1+exp(-z))
-        1.0 / (1.0 + (-z).exp())
-    } else {
-        // normal cdf
-        let dist = Normal::new(0.0, 1.0).unwrap();
-        dist.cdf(z)
-    }
-}
-
-/// Evaluate logistic or normal pdf
-fn pdf_logit_probit(is_logit: bool, z: f64) -> f64 {
-    if is_logit {
-        let e = z.exp();
-        e / (1.0 + e).powi(2)
-    } else {
-        (-0.5 * z * z).exp() / (2.0 * f64::consts::PI).sqrt()
-    }
-}
-
-/// Significance stars
-fn add_significance_stars(p: f64) -> &'static str {
-    if p < 0.001 {
-        "***"
-    } else if p < 0.01 {
-        "**"
-    } else if p < 0.05 {
-        "*"
-    } else if p < 0.1 {
-        "."
-    } else {
-        ""
-    }
-}
-
-/// Main AME function
-#[pyfunction]
-fn ame<'py>(
-    py: Python<'py>,
-    model: &'py PyAny, // statsmodels fitted results
-    chunk_size: Option<usize>,
-) -> PyResult<&'py PyAny> {
-    // 1) Detect Logit vs Probit
-    let is_logit = is_logit_model(model)?;
-
-    // 2) Extract 1D params => shape(k,)
-    let params_obj = model.getattr("params")?;
-    // Usually `params` is already a 1D numpy array, so:
-    let params_nd = params_obj.downcast::<PyArray1<f64>>()?;
-    let beta_view = unsafe { params_nd.as_array() };
-    let beta = beta_view.to_owned(); // shape(k)
-
-    let k = beta.len();
-
-    // 3) Extract covariance => shape(k,k)
-    let cov_py = model.call_method0("cov_params")?.downcast::<PyArray2<f64>>()?;
-    let cov_view = unsafe { cov_py.as_array() };
-    if cov_view.nrows() != k || cov_view.ncols() != k {
-        return Err(pyo3::exceptions::PyValueError::new_err(
-            "cov_params dimension mismatch with `params`"
-        ));
-    }
-    let cov_beta = cov_view.to_owned();
-
-    // 4) Extract exog_names from statsmodels
-    let model_obj = model.getattr("model")?;
-    let exog_names: Vec<String> = model_obj.getattr("exog_names")?.extract()?;
-
-    // 5) Extract exog
-    // In many cases, `model_obj.getattr("exog")` is already a NumPy array.
-    // But if it's a DataFrame, we forcibly do .to_numpy().
-    let exog_obj = model_obj.getattr("exog")?;
-    // We'll try direct downcast first:
-    let exog_pyarray = match exog_obj.downcast::<PyArray2<f64>>() {
-        Ok(arr) => arr,
-        Err(_e) => {
-            // fallback => call .to_numpy()
-            let np = exog_obj.call_method0("to_numpy")?;
-            np.downcast::<PyArray2<f64>>()?
-        }
-    };
-    let X_view = unsafe { exog_pyarray.as_array() };
-    let (n, k_exog) = X_view.dim();
-    if k_exog != k {
-        return Err(pyo3::exceptions::PyValueError::new_err(
-            format!("exog dimension mismatch: exog has {} columns, params has length={}", k_exog, k)
-        ));
-    }
-
-    let chunk = chunk_size.unwrap_or(n);
-
-    // 6) Identify intercept columns
-    let intercept_idx: Vec<usize> = exog_names
-        .iter()
-        .enumerate()
-        .filter_map(|(i, nm)| {
-            let ln = nm.to_lowercase();
-            if ln=="intercept" || ln=="const" {
-                Some(i)
-            } else {
-                None
-            }
-        })
-        .collect();
-
-    // 7) Identify discrete columns => strictly 0/1
-    let mut disc_idx = Vec::new();
-    for j in 0..k {
-        if intercept_idx.contains(&j) {
-            continue;
-        }
-        let col_j = X_view.column(j);
-        if col_j.iter().all(|&v| v==0.0 || v==1.0) {
-            disc_idx.push(j);
-        }
-    }
-
-    // 8) Single pass, chunked
-    let mut sum_ame = vec![0.0; k];
-    let mut partial_jl_sums = vec![0.0; k*k];
-
-    let mut idx_start = 0;
-    while idx_start < n {
-        let idx_end = (idx_start + chunk).min(n);
-        let x_chunk = X_view.slice(s![idx_start..idx_end, ..]);
-        let z_chunk = x_chunk.dot(&beta); // shape(n_chunk)
-
-        // pdf for continuous partials
-        let pdf_chunk = z_chunk.mapv(|z| pdf_logit_probit(is_logit, z));
-
-        // handle discrete
-        for &jj in &disc_idx {
-            let b_j = beta[jj];
-            let col_j = x_chunk.column(jj);
-            // z_j1 => z_chunk + (1-col_j)*b_j
-            // z_j0 => z_chunk - col_j*b_j
-            let delta_j1 = (1.0 - &col_j).mapv(|x| x*b_j);
-            let delta_j0 = col_j.mapv(|x| -x*b_j);
-            let z_j1 = &z_chunk + &delta_j1;
-            let z_j0 = &z_chunk + &delta_j0;
-
-            let cdf_j1 = z_j1.mapv(|z| cdf_logit_probit(is_logit, z));
-            let cdf_j0 = z_j0.mapv(|z| cdf_logit_probit(is_logit, z));
-            let eff = cdf_j1.sum() - cdf_j0.sum();
-            sum_ame[jj]+= eff;
-
-            let pdf_j1 = z_j1.mapv(|z| pdf_logit_probit(is_logit, z));
-            let pdf_j0 = z_j0.mapv(|z| pdf_logit_probit(is_logit, z));
-            for l in 0..k {
-                let grad = if l==jj {
-                    pdf_j1.sum()
-                } else {
-                    let col_l = x_chunk.column(l);
-                    (&pdf_j1 - &pdf_j0).dot(&col_l)
-                };
-                partial_jl_sums[jj*k + l]+= grad;
-            }
-        }
-
-        // handle continuous
-        for j in 0..k {
-            if intercept_idx.contains(&j) || disc_idx.contains(&j) {
-                continue;
-            }
-            let b_j = beta[j];
-            sum_ame[j]+= b_j* pdf_chunk.sum();
-
-            for l in 0..k {
-                let grad = if j==l {
-                    pdf_chunk.sum()
-                } else {
-                    let col_l = x_chunk.column(l);
-                    // - b_j * sum( z_chunk * col_l * pdf_chunk)
-                    -b_j * (&z_chunk * &col_l).dot(&pdf_chunk)
-                };
-                partial_jl_sums[j*k + l]+= grad;
-            }
-        }
-
-        idx_start= idx_end;
-    }
-
-    // 9) Average
-    let nf = n as f64;
-    let ame_vals: Vec<f64> = sum_ame.iter().map(|&v| v/nf).collect();
-
-    let mut grad_ame = ndarray::Array2::<f64>::zeros((k,k));
-    for j in 0..k {
-        for l in 0..k {
-            grad_ame[[j,l]] = partial_jl_sums[j*k + l]/ nf;
-        }
-    }
-
-    // cov(ame)
-    let j_c = grad_ame.dot(&cov_beta);
-    let cov_ame = j_c.dot(&grad_ame.t());
-    let normal = Normal::new(0.0, 1.0).unwrap();
-    let var_ame: Vec<f64> = cov_ame.diag().iter().map(|&x| x.max(0.0)).collect();
-    let se_vals: Vec<f64> = var_ame.iter().map(|&x| x.sqrt()).collect();
-
-    // 10) Build final output => skip intercept
-    let mut dy_dx = Vec::new();
-    let mut se_err= Vec::new();
-    let mut z_vals= Vec::new();
-    let mut p_vals= Vec::new();
-    let mut sigs= Vec::new();
-    let mut names_out= Vec::new();
-
-    for j in 0..k {
-        if intercept_idx.contains(&j) {
-            continue;
-        }
-        let dy = ame_vals[j];
-        let s = se_vals[j];
-        dy_dx.push(dy);
-        se_err.push(s);
-        if s>1e-14 {
-            let z = dy/s;
-            let p = 2.0*(1.0 - normal.cdf(z.abs()));
-            z_vals.push(z);
-            p_vals.push(p);
-            sigs.push(add_significance_stars(p));
-        } else {
-            z_vals.push(0.0);
-            p_vals.push(1.0);
-            sigs.push("");
-        }
-        names_out.push(exog_names[j].clone());
-    }
-
-    // 11) Build a Pandas DataFrame
-    let pd = py.import("pandas")?;
-    let data = PyDict::new(py);
-    data.set_item("dy/dx", &dy_dx)?;
-    data.set_item("Std. Err", &se_err)?;
-    data.set_item("z", &z_vals)?;
-    data.set_item("Pr(>|z|)", &p_vals)?;
-    data.set_item("Significance", &sigs)?;
-
-    let kwargs = PyDict::new(py);
-    kwargs.set_item("data", data)?;
-    kwargs.set_item("index", names_out)?;
-
-    let df = pd.call_method("DataFrame", (), Some(kwargs))?;
-    Ok(df)
-}
-
-/// Python module
-#[pymodule]
-fn febolt(_py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_function(wrap_pyfunction!(ame, m)?)?;
-    Ok(())
-}