SearchLibrium 0.0.72__tar.gz → 0.0.84__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/PKG-INFO +1 -1
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/pyproject.toml +1 -1
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/__init__.py +15 -7
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/search.py +425 -0
- searchlibrium-0.0.84/src/SearchLibrium/selection_models.py +268 -0
- searchlibrium-0.0.84/src/SearchLibrium/version.txt +1 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium.egg-info/PKG-INFO +1 -1
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium.egg-info/SOURCES.txt +1 -0
- searchlibrium-0.0.72/src/SearchLibrium/version.txt +0 -1
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/README.md +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/setup.cfg +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/Halton.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/MixedLogit.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/Mode_Activity_Nested.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/RandomP.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/SEARCH_SM_MARIO.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/Two_Level_Nest.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/__main__.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/_choice_model.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/_device.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/bhhh/minimize.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/boxcox_functions.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/call_meta.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/constraints_builder.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/harmony.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/latent_class.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/main.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/main_debug.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/misc.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/mixed_logit.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/mixed_nested.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/mixedrrm.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/multinomial_logit.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/multinomial_nested.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/multinomial_probit.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/ordered_logit.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/ordered_logit_mixed.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/rrm.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/setup.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/siman.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium/threshold.py +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium.egg-info/dependency_links.txt +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium.egg-info/entry_points.txt +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium.egg-info/requires.txt +0 -0
- {searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium.egg-info/top_level.txt +0 -0
|
@@ -59,7 +59,7 @@ Homepage = "https://github.com/zahern/HypothesisX"
|
|
|
59
59
|
realpython = "SearchLibrium.__main__:main"
|
|
60
60
|
|
|
61
61
|
[tool.bumpver]
|
|
62
|
-
current_version = "0.0.72"
|
|
62
|
+
current_version = "0.0.84"
|
|
63
63
|
version_pattern = "MAJOR.MINOR.PATCH"
|
|
64
64
|
commit_message = "[skip ci] Bump version {old_version} -> {new_version}"
|
|
65
65
|
commit = true
|
|
@@ -86,7 +86,9 @@ try:
|
|
|
86
86
|
from .rrm import RandomRegret
|
|
87
87
|
from .mixedrrm import MixedRandomRegret
|
|
88
88
|
from .ordered_logit import OrderedLogit, OrderedLogitLong
|
|
89
|
+
from .selection_models import BinaryProbit, HeckmanTwoStep
|
|
89
90
|
from .latent_class import LatentClassMixedLogit
|
|
91
|
+
from .multinomial_probit import MultinomialProbit
|
|
90
92
|
from .RandomP import RandomParameters
|
|
91
93
|
from .constraints_builder import ConstraintBuilder, create_constraints
|
|
92
94
|
from .search import Parameters
|
|
@@ -102,21 +104,27 @@ except ImportError as e:
|
|
|
102
104
|
from rrm import RandomRegret
|
|
103
105
|
from mixedrrm import MixedRandomRegret
|
|
104
106
|
from ordered_logit import OrderedLogit, OrderedLogitLong
|
|
107
|
+
from selection_models import BinaryProbit, HeckmanTwoStep
|
|
105
108
|
from latent_class import LatentClassMixedLogit
|
|
109
|
+
from multinomial_probit import MultinomialProbit
|
|
106
110
|
from RandomP import RandomParameters
|
|
107
111
|
from constraints_builder import ConstraintBuilder, create_constraints
|
|
108
112
|
from search import Parameters
|
|
109
113
|
from call_meta import call_siman, call_harmony, call_search, estimate_ctrl
|
|
110
114
|
try:
|
|
111
115
|
from .main import print_ascii_art_logo
|
|
112
|
-
except:
|
|
113
|
-
|
|
114
|
-
|
|
116
|
+
except Exception:
|
|
117
|
+
try:
|
|
118
|
+
from main import print_ascii_art_logo
|
|
119
|
+
except Exception:
|
|
120
|
+
print_ascii_art_logo = None
|
|
115
121
|
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
122
|
+
|
|
123
|
+
if print_ascii_art_logo is not None:
|
|
124
|
+
try:
|
|
125
|
+
print_ascii_art_logo()
|
|
126
|
+
except Exception:
|
|
127
|
+
print("SearchLibrium logo skipped; optional display dependencies are missing.")
|
|
120
128
|
|
|
121
129
|
#print('loaded all')
|
|
122
130
|
print('Welcome to SearchLibrium')
|
|
@@ -1206,8 +1206,413 @@ class Search():
|
|
|
1206
1206
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
1207
1207
|
asvars_new = self.create_dummy_column(self.param.asvarnames)
|
|
1208
1208
|
asvars_new = self.remove_redundant_asvars(asvars_new, self.param.trans_asvars, self.param.asvarnames)
|
|
1209
|
+
|
|
1210
|
+
# Pre-compute pairwise correlations & VIF for collinearity-aware solution generation
|
|
1211
|
+
self._precompute_correlations()
|
|
1209
1212
|
# }
|
|
1210
1213
|
|
|
1214
|
+
''' ---------------------------------------------------------- '''
|
|
1215
|
+
''' Function. Pre-compute pairwise Pearson correlations and '''
|
|
1216
|
+
''' VIF scores for all candidate variables. Called once on '''
|
|
1217
|
+
''' initialisation so that collinearity checks are fast during '''
|
|
1218
|
+
''' the search. '''
|
|
1219
|
+
''' ---------------------------------------------------------- '''
|
|
1220
|
+
def _precompute_correlations(self, corr_threshold=0.90, vif_threshold=10.0):
    """
    Cache pairwise Pearson correlations and Variance Inflation Factors (VIF)
    for the numeric candidate variables so that later collinearity checks
    are simple look-ups.

    Candidates are the entries of ``self.param.varnames`` (de-duplicated,
    order preserved) that are numeric columns of ``self.param.df``. Rows
    with a missing value in any candidate column are dropped, and the
    remaining data is converted to float so that boolean dummies take part
    in both the correlation and the VIF computation.

    Results stored on the instance:
        self._corr_matrix    : pd.DataFrame (variable x variable), or None
                               when nothing could be computed
        self._vif_scores     : dict {var: vif_value}
        self._high_corr_pairs: list of tuples [(var_a, var_b, r), ...]
        self._corr_threshold : float
        self._vif_threshold  : float

    Args:
        corr_threshold (float): |r| at or above which a pair is recorded.
        vif_threshold (float) : VIF above which a variable is logged as high.

    The method never raises: any failure is logged as a warning and the
    caches keep whatever had been computed up to that point.
    """
    import pandas as pd

    self._corr_threshold = corr_threshold
    self._vif_threshold = vif_threshold
    self._corr_matrix = None
    self._vif_scores = {}
    self._high_corr_pairs = []

    try:
        df = self.param.df
        # dict.fromkeys de-duplicates while keeping order; a name listed
        # twice would otherwise show up as perfectly correlated with itself.
        candidate_cols = [
            v for v in dict.fromkeys(self.param.varnames)
            if v in df.columns and pd.api.types.is_numeric_dtype(df[v])
        ]
        if len(candidate_cols) < 2:
            return

        # Force float: a bool/float mix would otherwise become an object
        # array and break the VIF arithmetic below.
        X = df[candidate_cols].dropna().astype(float)
        if len(X) < 3:
            # With two rows every correlation is trivially +/-1 (or NaN).
            logging.warning(
                "[Collinearity] Pre-computation skipped: only %d complete row(s).",
                len(X),
            )
            return

        # ── 1. Pearson correlation matrix ─────────────────────────
        self._corr_matrix = X.corr()

        # Highly correlated pairs, upper triangle only, in row-major order.
        cols = self._corr_matrix.columns.tolist()
        corr_values = self._corr_matrix.to_numpy()
        idx_a, idx_b = np.triu_indices(len(cols), k=1)
        upper = corr_values[idx_a, idx_b]
        # NaN entries (constant columns) compare False and are skipped.
        with np.errstate(invalid="ignore"):
            hits = np.flatnonzero(np.abs(upper) >= corr_threshold)
        self._high_corr_pairs = [
            (cols[idx_a[h]], cols[idx_b[h]], round(float(upper[h]), 4))
            for h in hits
        ]

        if self._high_corr_pairs:
            logging.info(
                "[Collinearity] %d highly correlated pair(s) detected (|r| >= %.2f):",
                len(self._high_corr_pairs), corr_threshold,
            )
            for va, vb, r in self._high_corr_pairs:
                logging.info(" %s <-> %s r = %.4f", va, vb, r)

        # ── 2. Variance Inflation Factors ─────────────────────────
        try:
            from numpy.linalg import lstsq

            Xmat = X.to_numpy()
            means = Xmat.mean(axis=0)
            stds = Xmat.std(axis=0)
            stds[stds == 0] = 1.0  # constant columns: avoid division by zero
            Xz = (Xmat - means) / stds
            intercept = np.ones(len(Xz))

            for k, col in enumerate(candidate_cols):
                # Regress column k on all other columns; VIF = 1 / (1 - R^2).
                y_k = Xz[:, k]
                X_oth = np.column_stack([intercept, np.delete(Xz, k, axis=1)])
                coef, _, _, _ = lstsq(X_oth, y_k, rcond=None)
                y_hat = X_oth @ coef
                ss_res = np.sum((y_k - y_hat) ** 2)
                ss_tot = np.sum((y_k - y_k.mean()) ** 2)
                r2 = 1.0 - ss_res / ss_tot if ss_tot > 1e-12 else 0.0
                # Clamp so that a perfect fit gives a huge but finite VIF.
                r2 = min(max(r2, 0.0), 1.0 - 1e-12)
                self._vif_scores[col] = round(1.0 / (1.0 - r2), 2)

            high_vif = {
                v: s for v, s in self._vif_scores.items()
                if s > vif_threshold
            }
            if high_vif:
                logging.info(
                    "[Collinearity] %d variable(s) with VIF > %.1f: %s",
                    len(high_vif), vif_threshold,
                    ', '.join(f"{v}={s}" for v, s in high_vif.items()),
                )
        except Exception as vif_err:
            logging.warning(
                "[Collinearity] VIF computation failed: %s", vif_err
            )

    except Exception as e:
        logging.warning("[Collinearity] Pre-computation failed: %s", e)
|
|
1313
|
+
|
|
1314
|
+
''' ---------------------------------------------------------- '''
|
|
1315
|
+
''' Function. Remove highly collinear variables from a list. '''
|
|
1316
|
+
''' Greedy approach: for each high-correlation pair remove the '''
|
|
1317
|
+
''' variable with the higher VIF (or second if VIF unavailable).'''
|
|
1318
|
+
''' Prespecified (protected) variables are never removed. '''
|
|
1319
|
+
''' ---------------------------------------------------------- '''
|
|
1320
|
+
def remove_collinear_vars(self, varlist, protected=None):
    """
    Filter `varlist` using the cached collinearity statistics.

    Two passes are applied, both driven by values cached on the instance
    (`_vif_scores`, `_vif_threshold`, `_high_corr_pairs`):

    1. Every unprotected variable whose cached VIF exceeds the threshold
       is dropped.
    2. For each cached high-correlation pair still fully present, one
       member is dropped: the unprotected one if the other is protected,
       otherwise the one with the higher cached VIF. On a VIF tie
       (including "no VIF available") the second member is dropped and
       the first kept.

    Args:
        varlist (list): Candidate variable names.
        protected (set) : Variables that must not be removed. The
            `ps_asvars` / `ps_isvars` attributes of `self.param`, when
            present and not None, are always added to this set.

    Returns:
        list: Filtered variable list. `varlist` itself is returned when it
        is empty or no correlation cache exists; a copy of `varlist` is
        returned when filtering would have removed everything.
    """
    if not varlist or getattr(self, '_corr_matrix', None) is None:
        return varlist

    protected = set(protected or [])
    # `or []` guards against the attributes existing but being None.
    protected |= set(getattr(self.param, 'ps_asvars', None) or [])
    protected |= set(getattr(self.param, 'ps_isvars', None) or [])

    active = list(varlist)
    present = set(active)  # O(1) membership for the pair scan below
    removed = set()

    # NOTE(review): cached VIFs are only ever written by the pre-computation
    # step, so all members of a collinear group can exceed the threshold and
    # be dropped together here — confirm that this is intended.
    # ── Step 1: VIF-based removal ─────────────────────────────
    for var in active:
        if var in removed or var in protected:
            continue
        vif = self._vif_scores.get(var, 0.0)
        if vif > self._vif_threshold:
            removed.add(var)
            logging.info(
                "[CollinearityConstraint] Removed '%s' (VIF=%.1f > %.1f)",
                var, vif, self._vif_threshold,
            )

    # ── Step 2: Pairwise correlation removal ──────────────────
    for va, vb, r in self._high_corr_pairs:
        if va not in present or vb not in present:
            continue
        if va in removed or vb in removed:
            continue
        a_locked = va in protected
        b_locked = vb in protected
        if a_locked and b_locked:
            continue  # neither member may be dropped
        if b_locked:
            drop, kept = va, vb
        elif a_locked:
            drop, kept = vb, va
        else:
            vif_a = self._vif_scores.get(va, 0.0)
            vif_b = self._vif_scores.get(vb, 0.0)
            # Strict comparison: on a tie the second variable is dropped.
            drop, kept = (va, vb) if vif_a > vif_b else (vb, va)

        removed.add(drop)
        logging.info(
            "[CollinearityConstraint] Removed '%s' (|r|=%.4f with '%s')",
            drop, abs(r), kept,
        )

    filtered = [v for v in active if v not in removed]
    return filtered if filtered else list(varlist)  # fallback: never return empty
|
|
1380
|
+
|
|
1381
|
+
''' ---------------------------------------------------------- '''
|
|
1382
|
+
''' Function. Check model prerequisites before fitting. '''
|
|
1383
|
+
''' Returns a list of warning strings (empty => all clear). '''
|
|
1384
|
+
''' ---------------------------------------------------------- '''
|
|
1385
|
+
def _check_model_prerequisites(self, all_vars, model_n=''):
    """
    Inspect the design matrix for common problems that cause gradient-based
    optimisers to fail to converge:

    1. Near-constant variables (std < 1e-8)
    2. Extreme scale disparity between columns (range ratio > 1e4)
    3. Near-singular design matrix (condition number of the standardised
       matrix > 1e6; values above 1e3 are logged at info level only)
    4. Insufficient observations-to-parameters ratio (fewer than 10 choice
       situations per variable)

    Checks 1-3 use only the numeric/boolean columns and only rows where all
    of them are finite, so a text column or a missing value no longer
    disables the whole check. Near-constant columns are reported by check 1
    and then left out of checks 2 and 3, where they would otherwise distort
    the range ratio and force an infinite condition number.

    Args:
        all_vars (list): Variable names in the design matrix.
        model_n (str) : Model type label (for logging).

    Returns:
        list[str]: Diagnostic warning messages (empty list if none). Each
        message is also emitted through `logging.warning`. Unexpected
        errors are logged at debug level and never raised.
    """
    warnings_out = []

    def _flag(msg):
        # Record the message for the caller and emit it to the log.
        warnings_out.append(msg)
        logging.warning(msg)

    try:
        df = self.param.df
        cols = [v for v in all_vars if v in df.columns]
        if not cols:
            return warnings_out

        n_obs, n_params = len(df), len(cols)

        numeric = df[cols].select_dtypes(include=['number', 'bool'])
        names = list(numeric.columns)
        X = numeric.astype(float).to_numpy()
        X = X[np.isfinite(X).all(axis=1)]  # complete, finite rows only

        if X.shape[0] > 0 and X.shape[1] > 0:
            # 1. Near-constant columns
            stds = X.std(axis=0)
            varying = ~(stds < 1e-8)
            near_const = [name for name, ok in zip(names, varying) if not ok]
            if near_const:
                _flag(
                    f"[Prerequisite/{model_n}] Near-constant variable(s) detected "
                    f"– may cause singular Hessian: {near_const}"
                )

            if varying.any():
                live = X[:, varying]

                # 2. Scale disparity (among columns that actually vary)
                col_ranges = live.max(axis=0) - live.min(axis=0)
                scale_ratio = col_ranges.max() / col_ranges.min()
                if scale_ratio > 1e4:
                    _flag(
                        f"[Prerequisite/{model_n}] Large scale disparity "
                        f"(max/min range ratio = {scale_ratio:.1e}). "
                        f"Consider standardising inputs to aid gradient convergence."
                    )

                # 3. Condition number (on standardised matrix)
                Xz = (live - live.mean(axis=0)) / stds[varying]
                try:
                    cond = np.linalg.cond(Xz)
                except np.linalg.LinAlgError as cond_err:
                    logging.debug(
                        "[_check_model_prerequisites] condition number failed: %s",
                        cond_err,
                    )
                else:
                    if cond > 1e6:
                        _flag(
                            f"[Prerequisite/{model_n}] Design matrix condition number "
                            f"= {cond:.2e} (> 1e6). High collinearity is very likely "
                            f"preventing gradient convergence."
                        )
                    elif cond > 1e3:
                        logging.info(
                            "[Prerequisite/%s] Moderate condition number = %.2e.", model_n, cond
                        )

        # 4. Obs-to-parameters ratio
        # Rows divided by the number of alternatives gives choice situations.
        n_cs = n_obs // max(len(self.param.choice_set), 1)
        if n_cs < n_params * 10:
            _flag(
                f"[Prerequisite/{model_n}] Low obs-to-params ratio "
                f"({n_cs} choice situations / {n_params} params). "
                f"Model may be overparameterised."
            )

    except Exception as e:
        logging.debug("[_check_model_prerequisites] %s", e)

    return warnings_out
|
|
1472
|
+
|
|
1473
|
+
''' ---------------------------------------------------------- '''
|
|
1474
|
+
''' Function. Diagnose why gradient optimisation failed to '''
|
|
1475
|
+
''' converge. Prints a structured diagnostic report to stdout. '''
|
|
1476
|
+
''' ---------------------------------------------------------- '''
|
|
1477
|
+
def _diagnose_nonconvergence(self, sol, model_n=''):
    """
    Called after a model fails to converge. Analyses the candidate variable
    set and prints potential causes together with remediation suggestions.

    Possible causes diagnosed:
        • Highly correlated predictors (from pre-computed correlation cache)
        • High VIF variables
        • Near-constant / near-zero-variance columns
        • Extreme scale differences
        • Ill-conditioned design matrix
        • Too many parameters relative to observations
        • Mixed-model specifics (draws, degenerate distributions)
        • RRM-specific advice

    The report is diagnostic only, so it is written never to raise: the
    matrix checks use only numeric/boolean columns and complete, finite
    rows; missing (`None`) solution fields are treated as empty; and any
    unexpected error ends the analysis with a short notice instead of
    propagating into the search loop.

    Args:
        sol (Solution): The non-converging solution. Read through `.get`
            with keys 'asvars', 'isvars', 'randvars', 'model_n', 'sol_num'.
        model_n (str) : Model type label for display.
    """
    from pandas.api.types import is_numeric_dtype

    as_vars = list(sol.get('asvars') or [])
    is_vars = list(sol.get('isvars') or [])
    randvars = sol.get('randvars') or {}
    chosen = set(as_vars) | set(is_vars) | set(randvars)
    # Report variables in the order of param.varnames.
    all_vars = [v for v in self.param.varnames if v in chosen]

    label = model_n or sol.get('model_n', '?')
    sep = '─' * 62
    print(f"\n{sep}")
    print(f"[NonConvergence Diagnostic] model={label} sol#={sol.get('sol_num','?')}")
    print(f" Variables : {all_vars}")
    print(sep)

    if not all_vars:
        print(" No variables – cannot diagnose.")
        print(sep)
        return

    df = self.param.df
    cols = [v for v in all_vars if v in df.columns]
    if not cols:
        print(" Solution vars not found in dataframe.")
        print(sep)
        return

    issues = False
    aborted = False
    try:
        n_obs, n_params = len(df), len(cols)
        n_cs = n_obs // max(len(self.param.choice_set), 1)
        sol_set = set(cols)

        # 1. High-correlation pairs among solution variables
        relevant = [
            (a, b, r)
            for a, b, r in (getattr(self, '_high_corr_pairs', None) or [])
            if a in sol_set and b in sol_set
        ]
        if relevant:
            issues = True
            print(" ⚠ HIGH CORRELATION detected among solution variables:")
            for a, b, r in relevant:
                print(f" {a} <-> {b} |r| = {abs(r):.4f}")
            print(" → Remove one variable from each correlated pair, or use")
            print(" PCA / orthogonalisation to decorrelate predictors.")

        # 2. High VIF
        vif_limit = getattr(self, '_vif_threshold', float('inf'))
        high_vif_sol = {
            v: s
            for v, s in (getattr(self, '_vif_scores', None) or {}).items()
            if v in sol_set and s > vif_limit
        }
        if high_vif_sol:
            issues = True
            print(" ⚠ HIGH VIF variables in solution:")
            for v, s in high_vif_sol.items():
                print(f" {v} VIF = {s:.1f}")
            print(" → Remove or combine the above variables.")

        # Numeric view of the solution columns for checks 3-5.
        numeric = df[cols].select_dtypes(include=['number', 'bool'])
        names = list(numeric.columns)
        X = numeric.astype(float).to_numpy()
        X = X[np.isfinite(X).all(axis=1)]

        if X.shape[0] > 0 and X.shape[1] > 0:
            # 3. Near-constant columns
            stds = X.std(axis=0)
            varying = ~(stds < 1e-8)
            near_const = [name for name, ok in zip(names, varying) if not ok]
            if near_const:
                issues = True
                print(f" ⚠ NEAR-CONSTANT variables (std ≈ 0): {near_const}")
                print(" → Remove them; they carry no information.")

            if varying.any():
                # Constant columns were reported above; leaving them in
                # would distort the range ratio and give cond# = inf.
                live = X[:, varying]

                # 4. Scale disparity
                col_ranges = live.max(axis=0) - live.min(axis=0)
                scale_ratio = col_ranges.max() / col_ranges.min()
                if scale_ratio > 1e4:
                    issues = True
                    print(f" ⚠ SCALE DISPARITY: max/min range ratio = {scale_ratio:.1e}")
                    print(" → Standardise variables (zero mean, unit variance).")

                # 5. Condition number
                Xz = (live - live.mean(axis=0)) / stds[varying]
                try:
                    cond = np.linalg.cond(Xz)
                except np.linalg.LinAlgError as cond_err:
                    logging.debug(
                        "[_diagnose_nonconvergence] condition number failed: %s",
                        cond_err,
                    )
                else:
                    if cond > 1e6:
                        issues = True
                        print(f" ⚠ ILL-CONDITIONED design matrix: cond# = {cond:.2e}")
                        print(" → Gradient descent cannot navigate this landscape.")
                        print(" Remedies: remove collinear vars, standardise data,")
                        print(" increase ftol/gtol, or try a different solver.")

        # 6. Obs-to-parameters ratio
        # NOTE(review): random variables are counted once as a column plus
        # twice here — confirm this matches the model's parameter count.
        total_params = n_params + len(randvars) * 2
        if n_cs < total_params * 5:
            issues = True
            print(f" ⚠ LOW OBS/PARAM RATIO: {n_cs} situations / {total_params} params")
            print(" → Reduce variables or random coefficients.")

        # 7. Mixed-model specifics
        if label in ('mixed_logit', 'mixed_random_regret'):
            n_draws = getattr(self.param, 'n_draws', 0) or 0
            if n_draws < 200:
                issues = True
                print(f" ⚠ LOW DRAW COUNT for {label}: n_draws = {n_draws}")
                print(" → Increase n_draws (≥ 500 recommended).")
            for var, distr in randvars.items():
                if (
                    var in df.columns
                    and is_numeric_dtype(df[var])
                    and df[var].dropna().std() < 1e-6
                ):
                    issues = True
                    print(f" ⚠ Random var '{var}' has near-zero variance in data.")
                    print(f" Assigning distribution '{distr}' to a constant variable")
                    print(" yields a degenerate likelihood surface.")
    except Exception as e:
        # A failing diagnostic must not take the search down with it.
        aborted = True
        logging.debug("[_diagnose_nonconvergence] %s", e)
        print(f" Diagnostic aborted early: {e}")

    # 8. RRM-specific advice
    if label in ('random_regret', 'mixed_random_regret'):
        print(" ℹ RRM convergence tips:")
        print(" • Attributes should vary across alternatives.")
        print(" • Avoid variables identical across all alternatives.")
        print(" • Verify id/alt/choice column mapping.")

    if not issues and not aborted:
        print(" ℹ No obvious collinearity / scale issues detected.")
        print(" Other possible causes: flat likelihood, poor starting values,")
        print(" insufficient iterations (maxiter), or numerical overflow in")
        print(" exp() transforms. Try increasing maxiter or tightening")
        print(" ftol/gtol, or supplying better init_coeff.")

    print(sep + "\n")
|
|
1615
|
+
|
|
1211
1616
|
''' ---------------------------------------------------------- '''
|
|
1212
1617
|
''' Function. Remove redundant variables from a list. '''
|
|
1213
1618
|
''' Ensure unique variables do not exist in different forms '''
|
|
@@ -1472,6 +1877,15 @@ class Search():
|
|
|
1472
1877
|
isvars = self.select_isvars()
|
|
1473
1878
|
|
|
1474
1879
|
|
|
1880
|
+
# ── Collinearity constraint: remove highly correlated / high-VIF vars
|
|
1881
|
+
# Protected vars (ps_asvars) are never dropped by this filter.
|
|
1882
|
+
asvars = self.remove_collinear_vars(asvars)
|
|
1883
|
+
# Ensure we still have at least one variable after filtering
|
|
1884
|
+
while (len(asvars) + len(isvars)) < 1:
|
|
1885
|
+
asvars = self.select_asvars()
|
|
1886
|
+
isvars = self.select_isvars()
|
|
1887
|
+
asvars = self.remove_collinear_vars(asvars)
|
|
1888
|
+
|
|
1475
1889
|
randvars = self.select_randvars(asvars)
|
|
1476
1890
|
bcvars, bctrans = self.select_bcvars(asvars)
|
|
1477
1891
|
cor, corvars = self.select_corvars(randvars, bcvars)
|
|
@@ -2866,6 +3280,8 @@ class Search():
|
|
|
2866
3280
|
# {
|
|
2867
3281
|
self.not_converged += 1
|
|
2868
3282
|
sol['converged'] = False
|
|
3283
|
+
# ── Convergence diagnostic: explain why the model did not converge
|
|
3284
|
+
self._diagnose_nonconvergence(sol, model_n=sol.get('model_n', ''))
|
|
2869
3285
|
# }
|
|
2870
3286
|
|
|
2871
3287
|
if self.param.verbose:
|
|
@@ -3532,6 +3948,15 @@ class Search():
|
|
|
3532
3948
|
def evaluate_model(self, sol):
|
|
3533
3949
|
# {
|
|
3534
3950
|
model_n = sol.get('model_n', '')
|
|
3951
|
+
|
|
3952
|
+
# ── Pre-fit collinearity / prerequisite check ────────────────
|
|
3953
|
+
as_vars = sol.get('asvars', [])
|
|
3954
|
+
is_vars = sol.get('isvars', [])
|
|
3955
|
+
randvars = sol.get('randvars', {})
|
|
3956
|
+
_all_chk = list(dict.fromkeys(as_vars + is_vars + list(randvars.keys())))
|
|
3957
|
+
self._check_model_prerequisites(_all_chk, model_n)
|
|
3958
|
+
# ─────────────────────────────────────────────────────────────
|
|
3959
|
+
|
|
3535
3960
|
if model_n == 'random_regret':
|
|
3536
3961
|
return self.evaluate_rrm(sol)
|
|
3537
3962
|
elif model_n == 'mixed_random_regret':
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from scipy.optimize import minimize
|
|
7
|
+
from scipy.stats import norm, t as student_t
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
from jax.scipy.special import ndtr as jax_ndtr
|
|
13
|
+
except ImportError: # pragma: no cover
|
|
14
|
+
jax = None
|
|
15
|
+
jnp = None
|
|
16
|
+
jax_ndtr = None
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
from ._choice_model import DiscreteChoiceModel
|
|
20
|
+
except ImportError:
|
|
21
|
+
from _choice_model import DiscreteChoiceModel
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BinaryProbit(DiscreteChoiceModel):
    """Binary probit estimated with JAX autodiff and scipy L-BFGS-B.

    Typical use is ``BinaryProbit().setup(X, y).fit()``; after fitting,
    coefficient estimates, standard errors, z/p-values, log-likelihood and
    AIC/BIC are stored as attributes on the instance.
    """

    def __init__(self, _jax=False):
        """Initialise an unfitted model.

        Args:
            _jax: Forwarded unchanged to the base-class constructor.
        """
        super().__init__(_jax)
        self.descr = "Binary Probit"
        self.result = None
        self._X_design = None

    def setup(self, X, y, varnames=None, fit_intercept=True):
        """Store the data and build the design matrix.

        Args:
            X: Array-like regressors with two dimensions (rows, columns).
            y: Array-like response; flattened to one dimension.
            varnames: Column names; defaults to ``x0, x1, ...``.
            fit_intercept: When true, a leading column of ones named
                ``"intercept"`` is added to the design matrix.

        Returns:
            BinaryProbit: ``self``, to allow call chaining.
        """
        X = np.asarray(X)
        y = np.asarray(y).reshape(-1)
        if varnames is None:
            varnames = [f"x{i}" for i in range(X.shape[1])]
        self.X = X
        self.y = y
        self.varnames = np.asarray(varnames, dtype="<U64")
        self.fit_intercept = bool(fit_intercept)
        self.sample_size = int(X.shape[0])
        if self.fit_intercept:
            self._X_design = np.column_stack([np.ones((X.shape[0], 1)), X])
            self._design_names = np.asarray(["intercept", *self.varnames], dtype="<U64")
        else:
            self._X_design = X.copy()
            self._design_names = self.varnames.copy()
        return self

    def _negloglik_jax(self, params, X, y):
        """Return the negative probit log-likelihood as a JAX scalar."""
        xb = X @ params
        # Clip probabilities away from 0 and 1 so both logs stay finite.
        p = jnp.clip(jax_ndtr(xb), 1e-10, 1.0 - 1e-10)
        ll = y * jnp.log(p) + (1.0 - y) * jnp.log(1.0 - p)
        return -jnp.sum(ll)

    def fit(self, disp=False, **fit_kwargs):
        """Estimate the coefficients by minimising the negative log-likelihood.

        Args:
            disp: Passed to the optimiser's ``disp`` option.
            **fit_kwargs: Only ``maxiter`` (default 1000) is used.

        Returns:
            scipy.optimize.OptimizeResult: The raw optimiser result, also
            stored as ``self.result``.

        Raises:
            ImportError: If JAX is not installed.
            RuntimeError: If ``setup`` has not been called.
        """
        if jax is None or jnp is None or jax_ndtr is None:
            raise ImportError("JAX is required for BinaryProbit")
        if self._X_design is None:
            raise RuntimeError("BinaryProbit.setup() must be called before fit()")

        X = jnp.asarray(self._X_design)
        y = jnp.asarray(self.y)
        init = np.zeros(X.shape[1], dtype=float)

        val_grad = jax.jit(jax.value_and_grad(self._negloglik_jax))

        def _obj(params_np):
            val, grad = val_grad(jnp.asarray(params_np), X, y)
            return float(val), np.asarray(grad, dtype=float)

        # jac=True tells scipy that `_obj` returns (value, gradient), so the
        # jitted function runs once per iterate instead of twice.
        res = minimize(
            fun=_obj,
            x0=init,
            jac=True,
            method="L-BFGS-B",
            options={"disp": bool(disp), "maxiter": int(fit_kwargs.pop("maxiter", 1000))},
        )
        self.result = res
        self.coeff_names = self._design_names.copy()
        self.coeff_est = np.asarray(res.x, dtype=float)
        self.loglik = float(-res.fun)
        self.converged = bool(res.success)
        self.total_fun_eval = int(getattr(res, "nfev", 0))

        # NOTE(review): the covariance is taken from the optimiser's own
        # inverse-Hessian estimate; confirm this is accurate enough for
        # inference or replace it with an exact Hessian.
        hess_inv = getattr(res, "hess_inv", None)
        if hess_inv is not None:
            if hasattr(hess_inv, "todense"):
                cov = np.asarray(hess_inv.todense(), dtype=float)
            else:
                cov = np.asarray(hess_inv, dtype=float)
            stderr = np.sqrt(np.clip(np.diag(cov), 1e-12, None))
        else:
            stderr = np.full_like(self.coeff_est, np.nan, dtype=float)

        self.stderr = stderr
        self.zvalues = self.coeff_est / np.where(stderr > 0, stderr, np.nan)
        # Survival function keeps precision where 1 - cdf would round to 0.
        self.pvalues = 2.0 * norm.sf(np.abs(self.zvalues))
        k = len(self.coeff_est)
        n = max(int(self.sample_size), 1)
        self.aic = float(2 * k - 2 * self.loglik)
        self.bic = float(k * np.log(n) - 2 * self.loglik)
        return res

    def predict_proba(self, X=None):
        """Return the fitted probability that the response equals one.

        Args:
            X: Regressors without the intercept column; defaults to the
                data passed to ``setup``.

        Raises:
            RuntimeError: If the model has not been fit.
        """
        if getattr(self, "coeff_est", None) is None:
            raise RuntimeError("BinaryProbit must be fit before prediction")
        X_arr = self.X if X is None else np.asarray(X)
        if self.fit_intercept:
            X_arr = np.column_stack([np.ones((X_arr.shape[0], 1)), X_arr])
        xb = X_arr @ self.coeff_est
        return norm.cdf(xb)

    def summary_frame(self):
        """Return coef/stderr/z/pvalue per coefficient as a DataFrame.

        An empty DataFrame is returned when the model has not been fit.
        """
        if getattr(self, "coeff_est", None) is None:
            return pd.DataFrame()
        return pd.DataFrame({
            "coef": self.coeff_est,
            "stderr": self.stderr,
            "z": self.zvalues,
            "pvalue": self.pvalues,
        }, index=self.coeff_names)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@dataclass
class _OLSResult:
    """Plain container for the outputs of an ordinary-least-squares fit.

    NOTE(review): the field names look like they mirror statsmodels'
    regression results (params / bse / tvalues / pvalues / llf) — confirm
    against the code that builds this object.
    """
    # presumably the estimated coefficients, indexed by regressor name — TODO confirm
    params: pd.Series
    # presumably the coefficient standard errors — TODO confirm
    bse: pd.Series
    # presumably the t-statistics (params / bse) — TODO confirm
    tvalues: pd.Series
    # presumably the p-values for the t-statistics — TODO confirm
    pvalues: pd.Series
    # presumably the log-likelihood of the fitted model — TODO confirm
    llf: float
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class HeckmanTwoStep(DiscreteChoiceModel):
    """Heckman selection model using JAX probit + closed-form OLS second stage.

    Stage 1 fits a ``BinaryProbit`` on the selection equation.  Stage 2
    regresses the outcome, using only the rows with ``selection_y == 1``, on
    the outcome regressors plus the inverse Mills ratio (IMR) evaluated at
    the stage-1 linear index.

    NOTE: the second-stage standard errors are the plain OLS ones
    (``sigma2 * pinv(X'X)``); they are not adjusted for the IMR being an
    estimated regressor.
    """

    def __init__(self, _jax=False):
        super(HeckmanTwoStep, self).__init__(_jax)
        self.descr = "Heckman Two-Step"
        self.selection_result = None
        self.outcome_result = None
        self.params_table = pd.DataFrame()

    @staticmethod
    def _as_design_matrix(values, label, dtype=None):
        """Return *values* as a 2-D array; a 1-D vector becomes a single column."""
        arr = np.asarray(values) if dtype is None else np.asarray(values, dtype=dtype)
        if arr.ndim == 1:
            arr = arr.reshape(-1, 1)
        if arr.ndim != 2:
            raise ValueError(f"{label} must be 1-D or 2-D, got {arr.ndim} dimensions")
        return arr

    @staticmethod
    def _inverse_mills_ratio(xb):
        """Return pdf(xb) / cdf(xb) for the standard normal, computed in log space.

        Dividing by a clipped CDF drives the ratio towards 0 for strongly
        negative ``xb``, whereas the true ratio grows like ``-xb``; the
        log-space form stays accurate in that tail.
        """
        xb = np.asarray(xb, dtype=float)
        return np.exp(norm.logpdf(xb) - norm.logcdf(xb))

    def setup(
        self,
        selection_X,
        selection_y,
        outcome_X,
        outcome_y,
        selection_varnames=None,
        outcome_varnames=None,
        fit_intercept=True,
    ):
        """Store the data for both equations and return ``self``.

        Parameters
        ----------
        selection_X, outcome_X : array-like
            Regressors of the selection and outcome equations, without an
            intercept column.  A 1-D vector is treated as one regressor.
        selection_y : array-like
            Selection indicator; rows equal to 1 are the selected ones.
        outcome_y : array-like
            Outcome values, one per row of ``selection_X`` (values on
            unselected rows are ignored by ``fit``).
        selection_varnames, outcome_varnames : sequence of str, optional
            Regressor names; default to ``s0, s1, ...`` and ``o0, o1, ...``.
        fit_intercept : bool
            Whether ``fit`` adds a constant to both equations.

        Raises
        ------
        ValueError
            If the four inputs do not share the same number of rows.
        """
        selection_X = self._as_design_matrix(selection_X, "selection_X")
        selection_y = np.asarray(selection_y).reshape(-1)
        outcome_X = self._as_design_matrix(outcome_X, "outcome_X")
        outcome_y = np.asarray(outcome_y).reshape(-1)

        # fit() masks outcome rows with selection_y, so all inputs must be
        # aligned row by row; fail here with a readable message.
        n_rows = selection_X.shape[0]
        row_counts = {
            "selection_y": selection_y.shape[0],
            "outcome_X": outcome_X.shape[0],
            "outcome_y": outcome_y.shape[0],
        }
        for label, count in row_counts.items():
            if count != n_rows:
                raise ValueError(
                    f"{label} has {count} rows but selection_X has {n_rows}"
                )

        if selection_varnames is None:
            selection_varnames = [f"s{i}" for i in range(selection_X.shape[1])]
        if outcome_varnames is None:
            outcome_varnames = [f"o{i}" for i in range(outcome_X.shape[1])]

        self.selection_X = selection_X
        self.selection_y = selection_y
        self.outcome_X = outcome_X
        self.outcome_y = outcome_y
        self.selection_varnames = np.asarray(selection_varnames, dtype="<U64")
        self.outcome_varnames = np.asarray(outcome_varnames, dtype="<U64")
        self.fit_intercept = bool(fit_intercept)
        self.sample_size = int(n_rows)
        return self

    def fit(self, disp=False, **fit_kwargs):
        """Estimate both stages and populate the result attributes.

        ``disp`` and ``fit_kwargs`` are forwarded to ``BinaryProbit.fit``.

        Returns
        -------
        dict
            ``{"probit": <BinaryProbit>, "ols": <_OLSResult>}``.

        Raises
        ------
        RuntimeError
            If ``setup`` has not been called.
        ValueError
            If no observation has ``selection_y == 1``.
        """
        if getattr(self, "selection_X", None) is None:
            raise RuntimeError("HeckmanTwoStep.setup must be called before fit")

        sel_raw = self._as_design_matrix(self.selection_X, "selection_X", dtype=float)
        out_raw = self._as_design_matrix(self.outcome_X, "outcome_X", dtype=float)
        if self.fit_intercept:
            sel_X = np.column_stack([np.ones((sel_raw.shape[0], 1)), sel_raw])
            out_X = np.column_stack([np.ones((out_raw.shape[0], 1)), out_raw])
        else:
            sel_X, out_X = sel_raw, out_raw

        # ---- Stage 1: probit on the selection equation -------------------
        # BinaryProbit receives the design without the constant together
        # with the fit_intercept flag.
        probit_model = BinaryProbit(_jax=True)
        probit_model.setup(
            sel_raw,
            self.selection_y,
            varnames=list(self.selection_varnames),
            fit_intercept=self.fit_intercept,
        )
        probit_model.fit(disp=disp, **fit_kwargs)

        xb = sel_X @ probit_model.coeff_est
        mills = self._inverse_mills_ratio(xb)

        # ---- Stage 2: OLS on the selected rows, augmented with the IMR ---
        mask = np.asarray(self.selection_y).reshape(-1) == 1
        if not mask.any():
            raise ValueError(
                "HeckmanTwoStep needs at least one observation with selection_y == 1"
            )
        out_design = np.column_stack([out_X[mask], mills[mask]])
        out_y = np.asarray(self.outcome_y[mask], dtype=float)

        xtx = out_design.T @ out_design
        xtx_inv = np.linalg.pinv(xtx)
        beta = xtx_inv @ (out_design.T @ out_y)
        resid = out_y - out_design @ beta
        n_obs, n_par = out_design.shape
        dof = max(n_obs - n_par, 1)
        rss = float(resid @ resid)
        sigma2 = rss / dof
        cov = sigma2 * xtx_inv
        # The clip keeps every standard error strictly positive.
        se = np.sqrt(np.clip(np.diag(cov), 1e-12, None))
        tvals = beta / se
        # sf() keeps precision in the tails where 1 - cdf() rounds to zero.
        pvals = 2.0 * student_t.sf(np.abs(tvals), df=dof)
        # The concentrated Gaussian log-likelihood -(n/2)(log(2*pi*s2) + 1)
        # is only valid at the ML variance RSS/n, not at RSS/dof.  The floor
        # avoids log(0) on a perfect fit.
        sigma2_mle = max(rss / n_obs, float(np.finfo(float).tiny))
        ll_ols = -0.5 * n_obs * (math.log(2.0 * math.pi * sigma2_mle) + 1.0)

        out_names = (
            (["intercept"] if self.fit_intercept else [])
            + list(self.outcome_varnames)
            + ["IMR"]
        )
        ols = _OLSResult(
            params=pd.Series(beta, index=out_names),
            bse=pd.Series(se, index=out_names),
            tvalues=pd.Series(tvals, index=out_names),
            pvalues=pd.Series(pvals, index=out_names),
            llf=float(ll_ols),
        )

        self.selection_result = probit_model
        self.outcome_result = ols
        self.loglik = float(probit_model.loglik + ll_ols)
        total_k = len(probit_model.coeff_est) + len(beta)
        self.aic = float(2 * total_k - 2 * self.loglik)
        self.bic = float(total_k * np.log(max(self.sample_size, 1)) - 2 * self.loglik)
        self.converged = bool(probit_model.converged)

        selection_tbl = pd.DataFrame({
            "coef": probit_model.coeff_est,
            "stderr": probit_model.stderr,
            "z": probit_model.zvalues,
            "pvalue": probit_model.pvalues,
        }, index=probit_model.coeff_names)
        # For the outcome equation the "z" column holds the OLS t statistics.
        outcome_tbl = pd.DataFrame({
            "coef": ols.params,
            "stderr": ols.bse,
            "z": ols.tvalues,
            "pvalue": ols.pvalues,
        })
        self.params_table = pd.concat(
            {"selection": selection_tbl, "outcome": outcome_tbl},
            names=["equation", "term"],
        )

        coeff_names = [f"selection::{name}" for name in selection_tbl.index]
        coeff_names += [f"outcome::{name}" for name in outcome_tbl.index]
        self.coeff_names = np.asarray(coeff_names, dtype="<U128")
        self.coeff_est = np.concatenate([selection_tbl["coef"].values, outcome_tbl["coef"].values])
        self.stderr = np.concatenate([selection_tbl["stderr"].values, outcome_tbl["stderr"].values])
        self.zvalues = np.concatenate([selection_tbl["z"].values, outcome_tbl["z"].values])
        self.pvalues = np.concatenate([selection_tbl["pvalue"].values, outcome_tbl["pvalue"].values])
        return {"probit": probit_model, "ols": ols}

    def predict_selection_proba(self, X=None):
        """Return the stage-1 selection probabilities.

        ``X`` holds selection regressors without an intercept column; when
        omitted, the ``selection_X`` given to ``setup`` is used.

        Raises
        ------
        RuntimeError
            If the model has not been fit.
        """
        if self.selection_result is None:
            raise RuntimeError("HeckmanTwoStep must be fit before prediction")
        X_arr = self.selection_X if X is None else np.asarray(X)
        return self.selection_result.predict_proba(X_arr)

    def predict_outcome(self, X=None, selection_probability=None):
        """Predict the outcome from the second-stage regression.

        Parameters
        ----------
        X : array-like, optional
            Outcome regressors without an intercept column; defaults to the
            ``outcome_X`` given to ``setup``.
        selection_probability : array-like or float, optional
            Selection probability for each row of ``X`` (a single value is
            applied to every row).  Defaults to the stage-1 probabilities of
            the stored selection design, which only lines up with ``X`` when
            ``X`` has the same number of rows as that design.

        Raises
        ------
        RuntimeError
            If the model has not been fit.
        ValueError
            If the number of probabilities does not match the rows of ``X``.
        """
        if self.outcome_result is None:
            raise RuntimeError("HeckmanTwoStep must be fit before prediction")
        X_arr = self._as_design_matrix(
            self.outcome_X if X is None else X, "X", dtype=float
        )
        n_rows = X_arr.shape[0]
        if self.fit_intercept:
            X_arr = np.column_stack([np.ones((n_rows, 1)), X_arr])

        if selection_probability is None:
            selection_probability = self.predict_selection_proba()
        prob = np.asarray(selection_probability, dtype=float).reshape(-1)
        if prob.shape[0] == 1 and n_rows != 1:
            prob = np.full(n_rows, prob[0])
        if prob.shape[0] != n_rows:
            raise ValueError(
                f"got {prob.shape[0]} selection probabilities for {n_rows} rows of X; "
                "pass selection_probability explicitly when predicting for new X"
            )

        # Recover the linear index from the probability, then evaluate the IMR.
        xb = norm.ppf(np.clip(prob, 1e-10, 1 - 1e-10))
        imr = self._inverse_mills_ratio(xb)
        X_aug = np.column_stack([X_arr, imr])
        return X_aug @ self.outcome_result.params.values

    def summary_frame(self):
        """Return a copy of the combined selection/outcome coefficient table."""
        return self.params_table.copy()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
0.0.84
|
|
@@ -28,6 +28,7 @@ src/SearchLibrium/ordered_logit.py
|
|
|
28
28
|
src/SearchLibrium/ordered_logit_mixed.py
|
|
29
29
|
src/SearchLibrium/rrm.py
|
|
30
30
|
src/SearchLibrium/search.py
|
|
31
|
+
src/SearchLibrium/selection_models.py
|
|
31
32
|
src/SearchLibrium/setup.py
|
|
32
33
|
src/SearchLibrium/siman.py
|
|
33
34
|
src/SearchLibrium/threshold.py
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
0.0.72
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{searchlibrium-0.0.72 → searchlibrium-0.0.84}/src/SearchLibrium.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|