febolt 0.1.56__tar.gz → 0.1.57__tar.gz
- {febolt-0.1.56 → febolt-0.1.57}/Cargo.lock +5 -5
- {febolt-0.1.56 → febolt-0.1.57}/Cargo.toml +1 -1
- {febolt-0.1.56 → febolt-0.1.57}/PKG-INFO +1 -1
- {febolt-0.1.56 → febolt-0.1.57}/pyproject.toml +1 -1
- febolt-0.1.57/src/lib.rs +296 -0
- febolt-0.1.56/src/lib.rs +0 -281
- {febolt-0.1.56 → febolt-0.1.57}/.github/workflows/CI.yml +0 -0
- {febolt-0.1.56 → febolt-0.1.57}/.gitignore +0 -0
- {febolt-0.1.56 → febolt-0.1.57}/README.md +0 -0
- {febolt-0.1.56 → febolt-0.1.57}/build.rs +0 -0
{febolt-0.1.56 → febolt-0.1.57}/Cargo.lock

@@ -362,7 +362,7 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
 
 [[package]]
 name = "febolt"
-version = "0.1.56"
+version = "0.1.57"
 dependencies = [
  "blas",
  "intel-mkl-src",
@@ -1046,9 +1046,9 @@ dependencies = [
 
 [[package]]
 name = "openssl"
-version = "0.10.
+version = "0.10.71"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "
+checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd"
 dependencies = [
  "bitflags",
  "cfg-if",
@@ -1087,9 +1087,9 @@ dependencies = [
 
 [[package]]
 name = "openssl-sys"
-version = "0.9.
+version = "0.9.106"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "
+checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd"
 dependencies = [
  "cc",
  "libc",
{febolt-0.1.56 → febolt-0.1.57}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "maturin"
 
 [project]
 name = "febolt"
-version = "0.1.56"
+version = "0.1.57"
 requires-python = ">=3.8"
 description = "A Rust-based Statistics and ML package, callable from Python."
 keywords = ["rust", "python", "Machine Learning", "Statistics", "pyo3"]
febolt-0.1.57/src/lib.rs (ADDED, +296 lines)

// src/lib.rs

use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyAny, PyDict, PyList, PyType};
use pyo3::wrap_pyfunction;

use numpy::{PyArray1, PyArray2, IntoPyArray};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s, Axis};
use statrs::distribution::{Continuous, Normal, ContinuousCDF};
use std::collections::HashMap;
use std::fmt::Write as _; // for building cluster-keys

/// Add significance stars.
fn add_significance_stars(p: f64) -> &'static str {
    if p < 0.01 {
        "***"
    } else if p < 0.05 {
        "**"
    } else if p < 0.1 {
        "*"
    } else {
        ""
    }
}

/// A tiny helper to check if we have a Logit or Probit from statsmodels
fn detect_model_type(model: &PyAny) -> Result<bool, PyErr> {
    // Return is_logit: true if logit, false if probit
    let model_obj = model.getattr("model").unwrap_or(model);
    let cls: String = model_obj
        .getattr("__class__")?
        .getattr("__name__")?
        .extract()?;
    let lc = cls.to_lowercase();
    if lc == "logit" {
        Ok(true) // is_logit = true
    } else if lc == "probit" {
        Ok(false) // is_logit = false
    } else {
        Err(pyo3::exceptions::PyValueError::new_err(
            format!("ame: only Logit or Probit supported, got {cls}"),
        ))
    }
}

/// Downcast a Python object to NumPy PyArray2<f64> => ndarray::ArrayView2<f64>.
///
/// Marked unsafe because .as_array() in pyo3-numpy is unsafe, trusting Python memory is valid.
fn as_array2_f64<'py>(obj: &'py PyAny) -> PyResult<ArrayView2<'py, f64>> {
    let pyarray = obj.downcast::<PyArray2<f64>>()?;
    let view = unsafe { pyarray.as_array() };
    Ok(view)
}

/// Similarly, for 1D arrays.
fn as_array1_f64<'py>(obj: &'py PyAny) -> PyResult<ArrayView1<'py, f64>> {
    let pyarray = obj.downcast::<PyArray1<f64>>()?;
    let view = unsafe { pyarray.as_array() };
    Ok(view)
}

/// Evaluate logistic or normal cdf
fn cdf_logit_probit(is_logit: bool, z: f64) -> f64 {
    if is_logit {
        // logistic cdf
        1.0 / (1.0 + (-z).exp())
    } else {
        // normal cdf
        let dist = Normal::new(0.0, 1.0).unwrap();
        dist.cdf(z)
    }
}

/// Evaluate logistic or normal pdf
fn pdf_logit_probit(is_logit: bool, z: f64) -> f64 {
    if is_logit {
        // logistic pdf
        let e = z.exp();
        e / (1.0 + e).powi(2)
    } else {
        // normal pdf
        (-0.5 * z * z).exp() / (2.0 * std::f64::consts::PI).sqrt()
    }
}

/// The main function that calculates Logit/Probit AME in the same style as the original Probit code
#[pyfunction]
fn ame<'py>(
    py: Python<'py>,
    model: &'py PyAny,         // Could be Logit or Probit
    chunk_size: Option<usize>, // optional chunk
) -> PyResult<&'py PyAny> {
    // 1) detect Logit vs Probit
    let is_logit = detect_model_type(model)?;

    // 2) read params
    let params_pyarray: &PyArray1<f64> = model.getattr("params")?.downcast()?;
    let beta = unsafe { params_pyarray.as_array() }; // shape(k)

    // 3) read cov
    let cov_pyarray: &PyArray2<f64> = model.call_method0("cov_params")?.downcast()?;
    let cov_beta = unsafe { cov_pyarray.as_array() }; // shape(k,k)

    // 4) Get model object and handle exog (X) and exog_names
    let model_obj = model.getattr("model").unwrap_or(model);

    // Handle pandas DataFrame input
    let exog_py = model_obj.getattr("exog")?;
    let (x_pyarray, exog_names) = if let Ok(values) = exog_py.getattr("values") {
        // Pandas DataFrame path
        (
            values.downcast::<PyArray2<f64>>()?,
            exog_py.getattr("columns")?.extract::<Vec<String>>()?,
        )
    } else {
        // Numpy array path
        (
            exog_py.downcast::<PyArray2<f64>>()?,
            model_obj.getattr("exog_names")?.extract::<Vec<String>>()?,
        )
    };

    let X = unsafe { x_pyarray.as_array() };
    let (n, k) = (X.nrows(), X.ncols());
    let chunk = chunk_size.unwrap_or(n);

    // 5) Identify intercept columns
    let intercept_indices: Vec<usize> = exog_names
        .iter()
        .enumerate()
        .filter_map(|(i, nm)| {
            let ln = nm.to_lowercase();
            if ln == "const" || ln == "intercept" {
                Some(i)
            } else {
                None
            }
        })
        .collect();

    // 6) Identify discrete columns => strictly 0/1
    let is_discrete: Vec<usize> = exog_names
        .iter()
        .enumerate()
        .filter_map(|(j, _)| {
            if intercept_indices.contains(&j) {
                None
            } else {
                let col_j = X.column(j);
                if col_j.iter().all(|&v| v == 0.0 || v == 1.0) {
                    Some(j)
                } else {
                    None
                }
            }
        })
        .collect();

    // 7) Prepare accumulators
    let mut sum_ame = vec![0.0; k]; // sum partial effects
    let mut partial_jl_sums = vec![0.0; k * k];
    let normal = Normal::new(0.0, 1.0).unwrap();

    // 8) single pass with chunk
    let mut idx_start = 0;
    while idx_start < n {
        let idx_end = (idx_start + chunk).min(n);
        let x_chunk = X.slice(s![idx_start..idx_end, ..]);
        let z_chunk = x_chunk.dot(&beta); // shape(n_chunk)

        // pdf => we might do partial for continuous
        let pdf_chunk = z_chunk.mapv(|z| pdf_logit_probit(is_logit, z));

        // handle discrete
        for &j in &is_discrete {
            let xj_col = x_chunk.column(j);
            let b_j = beta[j];
            // z_j1 => z + (1-xj)*b_j
            // z_j0 => z - xj*b_j
            let delta_j1 = (1.0 - &xj_col).mapv(|x| x * b_j);
            let delta_j0 = xj_col.mapv(|x| -x * b_j);
            let z_j1 = &z_chunk + &delta_j1;
            let z_j0 = &z_chunk + &delta_j0;

            let cdf_j1 = z_j1.mapv(|z| cdf_logit_probit(is_logit, z));
            let cdf_j0 = z_j0.mapv(|z| cdf_logit_probit(is_logit, z));
            // sum
            let effect_sum = cdf_j1.sum() - cdf_j0.sum();
            sum_ame[j] += effect_sum;

            // partial_jl_sums => row j, col l
            let pdf_j1 = z_j1.mapv(|z| pdf_logit_probit(is_logit, z));
            let pdf_j0 = z_j0.mapv(|z| pdf_logit_probit(is_logit, z));
            for l in 0..k {
                let grad = if l == j {
                    // special case
                    pdf_j1.sum()
                } else {
                    let x_l = x_chunk.column(l);
                    let diff_pdf = &pdf_j1 - &pdf_j0;
                    diff_pdf.dot(&x_l)
                };
                partial_jl_sums[j * k + l] += grad;
            }
        }

        // handle continuous
        for j in 0..k {
            if intercept_indices.contains(&j) || is_discrete.contains(&j) {
                continue;
            }
            let b_j = beta[j];
            // sum_ame
            sum_ame[j] += b_j * pdf_chunk.sum();
            // partial_jl_sums => row j, col l
            for l in 0..k {
                let grad = if j == l {
                    pdf_chunk.sum()
                } else {
                    // - b_j * sum(z_chunk * x_col(l) * pdf_chunk)
                    let x_l = x_chunk.column(l);
                    // careful about sign from the original code
                    -b_j * (&z_chunk * &x_l).dot(&pdf_chunk)
                };
                partial_jl_sums[j * k + l] += grad;
            }
        }

        idx_start = idx_end;
    }

    // 9) average sums
    let ame: Vec<f64> = sum_ame.iter().map(|v| v / (n as f64)).collect();

    // gradient matrix => shape(k,k)
    let mut grad_ame = Array2::<f64>::zeros((k, k));
    for j in 0..k {
        for l in 0..k {
            grad_ame[[j, l]] = partial_jl_sums[j * k + l] / (n as f64);
        }
    }

    // cov => grad_ame * cov_beta * grad_ame^T
    let cov_ame = grad_ame.dot(&cov_beta).dot(&grad_ame.t());
    let var_ame: Vec<f64> = cov_ame.diag().iter().map(|&v| v.max(0.0)).collect();
    let se_ame: Vec<f64> = var_ame.iter().map(|&v| v.sqrt()).collect();

    // 10) Build final results
    let (mut dy_dx, mut se_err, mut z_vals, mut p_vals, mut sig) =
        (Vec::new(), Vec::new(), Vec::new(), Vec::new(), Vec::new());
    let mut names_out = Vec::new();

    for j in 0..k {
        if intercept_indices.contains(&j) {
            continue;
        }
        let dy = ame[j];
        let s = se_ame[j];
        dy_dx.push(dy);
        se_err.push(s);
        if s > 1e-15 {
            let z = dy / s;
            let p = 2.0 * (1.0 - normal.cdf(z.abs()));
            z_vals.push(z);
            p_vals.push(p);
            sig.push(add_significance_stars(p));
        } else {
            z_vals.push(0.0);
            p_vals.push(1.0);
            sig.push("");
        }
        names_out.push(exog_names[j].clone());
    }

    // 11) Create DataFrame
    let pd = py.import("pandas")?;
    let data = PyDict::new(py);
    data.set_item("dy/dx", &dy_dx)?;
    data.set_item("Std. Err", &se_err)?;
    data.set_item("z", &z_vals)?;
    data.set_item("Pr(>|z|)", &p_vals)?;
    data.set_item("Significance", &sig)?;

    let kwargs = PyDict::new(py);
    kwargs.set_item("data", data)?;
    kwargs.set_item("index", &names_out)?;

    let df = pd.call_method("DataFrame", (), Some(kwargs))?;
    Ok(df)
}

#[pymodule]
fn febolt(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(ame, m)?)?;
    Ok(())
}
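Since ame is exposed through pyo3, it is called from Python on a fitted statsmodels result. Below is a minimal usage sketch (the data and variable names are illustrative, assuming febolt 0.1.57, numpy, statsmodels, and pandas are installed); it cross-checks the returned DataFrame against statsmodels' own delta-method marginal effects.

# Usage sketch: illustrative only, not taken from the package docs.
import numpy as np
import statsmodels.api as sm
import febolt

rng = np.random.default_rng(0)
n = 1000
x_cont = rng.normal(size=n)                  # continuous regressor
x_dummy = (rng.uniform(size=n) < 0.4) * 1.0  # strictly 0/1, so febolt treats it as discrete
X = sm.add_constant(np.column_stack([x_cont, x_dummy]))
p = 1.0 / (1.0 + np.exp(-(X @ np.array([0.5, 1.0, -0.8]))))
y = (rng.uniform(size=n) < p).astype(float)

res = sm.Logit(y, X).fit(disp=0)

# Returns a pandas DataFrame indexed by the non-intercept exog names,
# with columns: dy/dx, Std. Err, z, Pr(>|z|), Significance.
# The second argument is the optional chunk_size.
print(febolt.ame(res, 256))

# Cross-check: statsmodels' AMEs (dummy=True applies the same discrete
# 0/1 difference that febolt uses for x_dummy).
print(res.get_margeff(at="overall", dummy=True).summary())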
febolt-0.1.56/src/lib.rs (DELETED, -281 lines)

use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict};
use pyo3::wrap_pyfunction;

use numpy::{PyArray1, PyArray2};
use ndarray::{Array1, Array2, Axis, s};
use statrs::distribution::{Normal, ContinuousCDF};
use std::f64;

/// Distinguish whether statsmodels model is Logit or Probit
fn is_logit_model(model: &PyAny) -> PyResult<bool> {
    // Return true if Logit, false if Probit
    let model_obj = model.getattr("model").unwrap_or(model);
    let cls_name: String = model_obj.getattr("__class__")?.getattr("__name__")?.extract()?;
    let lower = cls_name.to_lowercase();
    if lower == "logit" {
        Ok(true)
    } else if lower == "probit" {
        Ok(false)
    } else {
        Err(pyo3::exceptions::PyValueError::new_err(
            format!("ame() only supports Logit or Probit. Got: {cls_name}"),
        ))
    }
}

/// Evaluate logistic or normal cdf
fn cdf_logit_probit(is_logit: bool, z: f64) -> f64 {
    if is_logit {
        // logistic cdf => 1/(1+exp(-z))
        1.0 / (1.0 + (-z).exp())
    } else {
        // normal cdf
        let dist = Normal::new(0.0, 1.0).unwrap();
        dist.cdf(z)
    }
}

/// Evaluate logistic or normal pdf
fn pdf_logit_probit(is_logit: bool, z: f64) -> f64 {
    if is_logit {
        let e = z.exp();
        e / (1.0 + e).powi(2)
    } else {
        (-0.5 * z * z).exp() / (2.0 * f64::consts::PI).sqrt()
    }
}

/// Significance stars
fn add_significance_stars(p: f64) -> &'static str {
    if p < 0.001 {
        "***"
    } else if p < 0.01 {
        "**"
    } else if p < 0.05 {
        "*"
    } else if p < 0.1 {
        "."
    } else {
        ""
    }
}

/// Main AME function
#[pyfunction]
fn ame<'py>(
    py: Python<'py>,
    model: &'py PyAny, // statsmodels fitted results
    chunk_size: Option<usize>,
) -> PyResult<&'py PyAny> {
    // 1) Detect Logit vs Probit
    let is_logit = is_logit_model(model)?;

    // 2) Extract 1D params => shape(k,)
    let params_obj = model.getattr("params")?;
    // Usually `params` is already a 1D numpy array, so:
    let params_nd = params_obj.downcast::<PyArray1<f64>>()?;
    let beta_view = unsafe { params_nd.as_array() };
    let beta = beta_view.to_owned(); // shape(k)

    let k = beta.len();

    // 3) Extract covariance => shape(k,k)
    let cov_py = model.call_method0("cov_params")?.downcast::<PyArray2<f64>>()?;
    let cov_view = unsafe { cov_py.as_array() };
    if cov_view.nrows() != k || cov_view.ncols() != k {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "cov_params dimension mismatch with `params`"
        ));
    }
    let cov_beta = cov_view.to_owned();

    // 4) Extract exog_names from statsmodels
    let model_obj = model.getattr("model")?;
    let exog_names: Vec<String> = model_obj.getattr("exog_names")?.extract()?;

    // 5) Extract exog
    // In many cases, `model_obj.getattr("exog")` is already a NumPy array.
    // But if it's a DataFrame, we forcibly do .to_numpy().
    let exog_obj = model_obj.getattr("exog")?;
    // We'll try direct downcast first:
    let exog_pyarray = match exog_obj.downcast::<PyArray2<f64>>() {
        Ok(arr) => arr,
        Err(_e) => {
            // fallback => call .to_numpy()
            let np = exog_obj.call_method0("to_numpy")?;
            np.downcast::<PyArray2<f64>>()?
        }
    };
    let X_view = unsafe { exog_pyarray.as_array() };
    let (n, k_exog) = X_view.dim();
    if k_exog != k {
        return Err(pyo3::exceptions::PyValueError::new_err(
            format!("exog dimension mismatch: exog has {} columns, params has length={}", k_exog, k)
        ));
    }

    let chunk = chunk_size.unwrap_or(n);

    // 6) Identify intercept columns
    let intercept_idx: Vec<usize> = exog_names
        .iter()
        .enumerate()
        .filter_map(|(i, nm)| {
            let ln = nm.to_lowercase();
            if ln == "intercept" || ln == "const" {
                Some(i)
            } else {
                None
            }
        })
        .collect();

    // 7) Identify discrete columns => strictly 0/1
    let mut disc_idx = Vec::new();
    for j in 0..k {
        if intercept_idx.contains(&j) {
            continue;
        }
        let col_j = X_view.column(j);
        if col_j.iter().all(|&v| v == 0.0 || v == 1.0) {
            disc_idx.push(j);
        }
    }

    // 8) Single pass, chunked
    let mut sum_ame = vec![0.0; k];
    let mut partial_jl_sums = vec![0.0; k * k];

    let mut idx_start = 0;
    while idx_start < n {
        let idx_end = (idx_start + chunk).min(n);
        let x_chunk = X_view.slice(s![idx_start..idx_end, ..]);
        let z_chunk = x_chunk.dot(&beta); // shape(n_chunk)

        // pdf for continuous partials
        let pdf_chunk = z_chunk.mapv(|z| pdf_logit_probit(is_logit, z));

        // handle discrete
        for &jj in &disc_idx {
            let b_j = beta[jj];
            let col_j = x_chunk.column(jj);
            // z_j1 => z_chunk + (1-col_j)*b_j
            // z_j0 => z_chunk - col_j*b_j
            let delta_j1 = (1.0 - &col_j).mapv(|x| x * b_j);
            let delta_j0 = col_j.mapv(|x| -x * b_j);
            let z_j1 = &z_chunk + &delta_j1;
            let z_j0 = &z_chunk + &delta_j0;

            let cdf_j1 = z_j1.mapv(|z| cdf_logit_probit(is_logit, z));
            let cdf_j0 = z_j0.mapv(|z| cdf_logit_probit(is_logit, z));
            let eff = cdf_j1.sum() - cdf_j0.sum();
            sum_ame[jj] += eff;

            let pdf_j1 = z_j1.mapv(|z| pdf_logit_probit(is_logit, z));
            let pdf_j0 = z_j0.mapv(|z| pdf_logit_probit(is_logit, z));
            for l in 0..k {
                let grad = if l == jj {
                    pdf_j1.sum()
                } else {
                    let col_l = x_chunk.column(l);
                    (&pdf_j1 - &pdf_j0).dot(&col_l)
                };
                partial_jl_sums[jj * k + l] += grad;
            }
        }

        // handle continuous
        for j in 0..k {
            if intercept_idx.contains(&j) || disc_idx.contains(&j) {
                continue;
            }
            let b_j = beta[j];
            sum_ame[j] += b_j * pdf_chunk.sum();

            for l in 0..k {
                let grad = if j == l {
                    pdf_chunk.sum()
                } else {
                    let col_l = x_chunk.column(l);
                    // - b_j * sum( z_chunk * col_l * pdf_chunk)
                    -b_j * (&z_chunk * &col_l).dot(&pdf_chunk)
                };
                partial_jl_sums[j * k + l] += grad;
            }
        }

        idx_start = idx_end;
    }

    // 9) Average
    let nf = n as f64;
    let ame_vals: Vec<f64> = sum_ame.iter().map(|&v| v / nf).collect();

    let mut grad_ame = ndarray::Array2::<f64>::zeros((k, k));
    for j in 0..k {
        for l in 0..k {
            grad_ame[[j, l]] = partial_jl_sums[j * k + l] / nf;
        }
    }

    // cov(ame)
    let j_c = grad_ame.dot(&cov_beta);
    let cov_ame = j_c.dot(&grad_ame.t());
    let normal = Normal::new(0.0, 1.0).unwrap();
    let var_ame: Vec<f64> = cov_ame.diag().iter().map(|&x| x.max(0.0)).collect();
    let se_vals: Vec<f64> = var_ame.iter().map(|&x| x.sqrt()).collect();

    // 10) Build final output => skip intercept
    let mut dy_dx = Vec::new();
    let mut se_err = Vec::new();
    let mut z_vals = Vec::new();
    let mut p_vals = Vec::new();
    let mut sigs = Vec::new();
    let mut names_out = Vec::new();

    for j in 0..k {
        if intercept_idx.contains(&j) {
            continue;
        }
        let dy = ame_vals[j];
        let s = se_vals[j];
        dy_dx.push(dy);
        se_err.push(s);
        if s > 1e-14 {
            let z = dy / s;
            let p = 2.0 * (1.0 - normal.cdf(z.abs()));
            z_vals.push(z);
            p_vals.push(p);
            sigs.push(add_significance_stars(p));
        } else {
            z_vals.push(0.0);
            p_vals.push(1.0);
            sigs.push("");
        }
        names_out.push(exog_names[j].clone());
    }

    // 11) Build a Pandas DataFrame
    let pd = py.import("pandas")?;
    let data = PyDict::new(py);
    data.set_item("dy/dx", &dy_dx)?;
    data.set_item("Std. Err", &se_err)?;
    data.set_item("z", &z_vals)?;
    data.set_item("Pr(>|z|)", &p_vals)?;
    data.set_item("Significance", &sigs)?;

    let kwargs = PyDict::new(py);
    kwargs.set_item("data", data)?;
    kwargs.set_item("index", names_out)?;

    let df = pd.call_method("DataFrame", (), Some(kwargs))?;
    Ok(df)
}

/// Python module
#[pymodule]
fn febolt(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(ame, m)?)?;
    Ok(())
}
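Both versions implement the same estimator, so the algorithm is easiest to see stripped of the pyo3 plumbing. The sketch below (numpy/scipy only, an illustration rather than a reimplementation of the package) reproduces the point estimates: discrete 0/1 columns get an averaged counterfactual cdf difference, continuous columns get beta_j times the averaged pdf.

# AME point estimates, mirroring the Rust loops above (illustrative sketch).
import numpy as np
from scipy.stats import norm

def ame_point_estimates(X, beta, is_logit, discrete_cols):
    # X: (n, k) design matrix; beta: (k,) coefficients.
    z = X @ beta
    if is_logit:
        cdf = lambda t: 1.0 / (1.0 + np.exp(-t))
        pdf = lambda t: np.exp(t) / (1.0 + np.exp(t)) ** 2
    else:
        cdf, pdf = norm.cdf, norm.pdf

    ame = np.zeros(len(beta))
    for j in range(len(beta)):
        if j in discrete_cols:
            # Discrete: mean of P(y=1 | x_j=1) - P(y=1 | x_j=0),
            # via z_j1 = z + (1 - x_j) * b_j and z_j0 = z - x_j * b_j.
            z_j1 = z + (1.0 - X[:, j]) * beta[j]
            z_j0 = z - X[:, j] * beta[j]
            ame[j] = np.mean(cdf(z_j1) - cdf(z_j0))
        else:
            # Continuous: b_j * mean(pdf(z)).
            ame[j] = beta[j] * np.mean(pdf(z))
    return ame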
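The standard errors in both versions come from the delta method: the Jacobian of the AME vector with respect to beta (grad_ame, averaged over observations) is sandwiched around cov_params, then z and p values follow from the standard normal. A short sketch of that step, with the same caveats as above:

# Delta-method step, as in grad_ame.dot(&cov_beta).dot(&grad_ame.t()) above.
import numpy as np
from scipy.stats import norm

def ame_inference(ame, grad_ame, cov_beta):
    # cov(AME) = J cov(beta) J^T; clamp tiny negative variances to zero.
    cov_ame = grad_ame @ cov_beta @ grad_ame.T
    se = np.sqrt(np.clip(np.diag(cov_ame), 0.0, None))
    # z = ame/se, p = 2*(1 - Phi(|z|)); degenerate SEs get z=0, p=1,
    # matching the Rust fallback branch.
    safe_se = np.where(se > 1e-15, se, 1.0)
    z = np.where(se > 1e-15, ame / safe_se, 0.0)
    p = 2.0 * (1.0 - norm.cdf(np.abs(z)))
    return se, z, p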