bossanova 0.1.0.dev15__tar.gz → 0.1.0.dev16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241) hide show
  1. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/PKG-INFO +1 -1
  2. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/builders/state.py +4 -0
  3. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/structs/state.py +3 -1
  4. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/design/reference.py +93 -0
  5. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/compute.py +31 -1
  6. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/emm.py +190 -31
  7. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/slopes.py +23 -11
  8. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/model/core.py +4 -0
  9. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/pyproject.toml +4 -2
  10. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/.gitignore +0 -0
  11. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/LICENSE +0 -0
  12. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/README.md +0 -0
  13. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/__init__.py +0 -0
  14. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/README.md +0 -0
  15. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/__init__.py +0 -0
  16. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/advertising.csv +0 -0
  17. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/cake.csv +0 -0
  18. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/chickweight.csv +0 -0
  19. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/credit.csv +0 -0
  20. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/gammas.csv +0 -0
  21. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/mtcars.csv +0 -0
  22. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/penguins.csv +0 -0
  23. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/poker.csv +0 -0
  24. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/sleep.csv +0 -0
  25. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/titanic.csv +0 -0
  26. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/titanic_test.csv +0 -0
  27. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/data/titanic_train.csv +0 -0
  28. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/distributions/__init__.py +0 -0
  29. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/distributions/continuous.py +0 -0
  30. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/distributions/discrete.py +0 -0
  31. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/distributions/varying.py +0 -0
  32. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/expressions.py +0 -0
  33. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/__init__.py +0 -0
  34. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/__init__.py +0 -0
  35. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/builders/__init__.py +0 -0
  36. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/builders/data.py +0 -0
  37. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/builders/dataframes.py +0 -0
  38. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/builders/resamples.py +0 -0
  39. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/builders/results.py +0 -0
  40. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/builders/specs.py +0 -0
  41. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/schemas.py +0 -0
  42. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/structs/__init__.py +0 -0
  43. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/structs/data.py +0 -0
  44. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/structs/display.py +0 -0
  45. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/structs/explore.py +0 -0
  46. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/structs/formula.py +0 -0
  47. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/structs/specs.py +0 -0
  48. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/containers/validators.py +0 -0
  49. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/__init__.py +0 -0
  50. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/backend/__init__.py +0 -0
  51. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/backend/dispatch.py +0 -0
  52. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/backend/jax.py +0 -0
  53. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/backend/numpy.py +0 -0
  54. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/backend/protocol.py +0 -0
  55. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/batching.py +0 -0
  56. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/config.py +0 -0
  57. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/convergence.py +0 -0
  58. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/design/__init__.py +0 -0
  59. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/design/coding.py +0 -0
  60. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/design/names.py +0 -0
  61. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/design/z_matrix.py +0 -0
  62. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/differentiation.py +0 -0
  63. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/distributions/__init__.py +0 -0
  64. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/distributions/algebra.py +0 -0
  65. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/distributions/base.py +0 -0
  66. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/distributions/core.py +0 -0
  67. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/distributions/derived.py +0 -0
  68. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/distributions/factories.py +0 -0
  69. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/distributions/plotting.py +0 -0
  70. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/distributions/probability.py +0 -0
  71. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/family/__init__.py +0 -0
  72. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/family/binomial.py +0 -0
  73. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/family/create.py +0 -0
  74. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/family/gamma.py +0 -0
  75. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/family/gaussian.py +0 -0
  76. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/family/links.py +0 -0
  77. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/family/poisson.py +0 -0
  78. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/family/response.py +0 -0
  79. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/family/schema.py +0 -0
  80. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/family/tdist.py +0 -0
  81. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/__init__.py +0 -0
  82. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/contrasts.py +0 -0
  83. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/diagnostics.py +0 -0
  84. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/estimation.py +0 -0
  85. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/hypothesis.py +0 -0
  86. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/information_criteria.py +0 -0
  87. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/multiplicity.py +0 -0
  88. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/profile.py +0 -0
  89. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/sandwich.py +0 -0
  90. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/satterthwaite.py +0 -0
  91. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/wald_variance.py +0 -0
  92. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/inference/welch.py +0 -0
  93. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/linalg/__init__.py +0 -0
  94. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/linalg/qr.py +0 -0
  95. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/linalg/schur.py +0 -0
  96. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/linalg/sparse.py +0 -0
  97. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/linalg/svd.py +0 -0
  98. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/predict.py +0 -0
  99. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/rng.py +0 -0
  100. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/rounding.py +0 -0
  101. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/__init__.py +0 -0
  102. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/glm.py +0 -0
  103. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/glmer.py +0 -0
  104. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/heuristics.py +0 -0
  105. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/initialization.py +0 -0
  106. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/lambda_builder.py +0 -0
  107. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/lambda_sparse.py +0 -0
  108. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/lambda_template.py +0 -0
  109. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/lmer.py +0 -0
  110. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/optimize.py +0 -0
  111. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/pirls_sparse.py +0 -0
  112. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/solvers/quadrature.py +0 -0
  113. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/tolerances.py +0 -0
  114. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/transforms.py +0 -0
  115. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/variance.py +0 -0
  116. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/maths/weights.py +0 -0
  117. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/__init__.py +0 -0
  118. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/bundle.py +0 -0
  119. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/common/__init__.py +0 -0
  120. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/common/contrast_registry.py +0 -0
  121. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/common/data_utils.py +0 -0
  122. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/common/factors.py +0 -0
  123. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/common/formula_utils.py +0 -0
  124. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/compare/__init__.py +0 -0
  125. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/compare/compare.py +0 -0
  126. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/compare/cv.py +0 -0
  127. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/compare/deviance.py +0 -0
  128. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/compare/f_test.py +0 -0
  129. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/compare/helpers.py +0 -0
  130. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/compare/lrt.py +0 -0
  131. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/compare/lrt_compare.py +0 -0
  132. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/compare/refit.py +0 -0
  133. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/contrasts.py +0 -0
  134. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/convergence.py +0 -0
  135. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/diagnostics.py +0 -0
  136. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/fit/__init__.py +0 -0
  137. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/fit/dispatch.py +0 -0
  138. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/fit/glm.py +0 -0
  139. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/fit/glmer.py +0 -0
  140. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/fit/lmer.py +0 -0
  141. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/fit/ols.py +0 -0
  142. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/fit/rank.py +0 -0
  143. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/__init__.py +0 -0
  144. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/design.py +0 -0
  145. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/encoding.py +0 -0
  146. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/evaluate.py +0 -0
  147. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/evaluate_newdata.py +0 -0
  148. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/evaluate_transforms.py +0 -0
  149. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/helpers.py +0 -0
  150. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/parse.py +0 -0
  151. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/parser/__init__.py +0 -0
  152. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/parser/expr.py +0 -0
  153. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/parser/parser.py +0 -0
  154. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/parser/scanner.py +0 -0
  155. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/parser/token.py +0 -0
  156. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/formula/random_effects.py +0 -0
  157. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/__init__.py +0 -0
  158. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/asymptotic.py +0 -0
  159. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/bootstrap.py +0 -0
  160. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/cv.py +0 -0
  161. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/mee.py +0 -0
  162. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/params.py +0 -0
  163. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/permutation.py +0 -0
  164. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/prediction.py +0 -0
  165. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/profile.py +0 -0
  166. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/resample_bundle.py +0 -0
  167. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/satterthwaite_emm.py +0 -0
  168. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/infer/simulation.py +0 -0
  169. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/__init__.py +0 -0
  170. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/bracket_contrasts.py +0 -0
  171. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/conditions.py +0 -0
  172. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/contrasts.py +0 -0
  173. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/explore.py +0 -0
  174. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/explore_parser.py +0 -0
  175. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/explore_scanner.py +0 -0
  176. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/grid.py +0 -0
  177. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/inference.py +0 -0
  178. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/joint_tests.py +0 -0
  179. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/marginal/validation.py +0 -0
  180. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/predict.py +0 -0
  181. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/profile.py +0 -0
  182. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/rendering/__init__.py +0 -0
  183. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/rendering/latex.py +0 -0
  184. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/__init__.py +0 -0
  185. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/common.py +0 -0
  186. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/core.py +0 -0
  187. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/glm.py +0 -0
  188. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/glmer.py +0 -0
  189. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/lm.py +0 -0
  190. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/lm_bca.py +0 -0
  191. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/lm_operators.py +0 -0
  192. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/lmer.py +0 -0
  193. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/mixed.py +0 -0
  194. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/results.py +0 -0
  195. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/resample/utils.py +0 -0
  196. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/simulation/__init__.py +0 -0
  197. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/simulation/dgp/__init__.py +0 -0
  198. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/simulation/dgp/generate.py +0 -0
  199. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/simulation/dgp/glm.py +0 -0
  200. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/simulation/dgp/glmer.py +0 -0
  201. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/simulation/dgp/lm.py +0 -0
  202. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/simulation/dgp/lmer.py +0 -0
  203. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/simulation/harness.py +0 -0
  204. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/simulation/metrics.py +0 -0
  205. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/simulation/model_sim.py +0 -0
  206. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/simulation/power.py +0 -0
  207. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/transforms.py +0 -0
  208. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/operations/varying.py +0 -0
  209. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/README.md +0 -0
  210. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/__init__.py +0 -0
  211. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/cognition.py +0 -0
  212. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/compare.py +0 -0
  213. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/core.py +0 -0
  214. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/core_data.py +0 -0
  215. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/core_protocols.py +0 -0
  216. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/core_sizing.py +0 -0
  217. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/core_viz.py +0 -0
  218. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/dag.py +0 -0
  219. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/design.py +0 -0
  220. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/fit.py +0 -0
  221. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/fit_builders.py +0 -0
  222. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/fit_layers.py +0 -0
  223. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/helpers.py +0 -0
  224. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/lattice.py +0 -0
  225. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/layout.py +0 -0
  226. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/mem.py +0 -0
  227. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/params.py +0 -0
  228. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/predict.py +0 -0
  229. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/profile.py +0 -0
  230. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/ranef.py +0 -0
  231. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/relationships.py +0 -0
  232. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/resamples.py +0 -0
  233. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/resid.py +0 -0
  234. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/internal/viz/vif.py +0 -0
  235. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/model/__init__.py +0 -0
  236. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/model/guards.py +0 -0
  237. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/model/summary.py +0 -0
  238. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/bossanova/py.typed +0 -0
  239. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/tests/bossanova_benchmarks/bootstrap/data/README.md +0 -0
  240. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/tests/bossanova_benchmarks/insteval/data/README.md +0 -0
  241. {bossanova-0.1.0.dev15 → bossanova-0.1.0.dev16}/tests/bossanova_tests/hypothesis/README.md +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bossanova
3
- Version: 0.1.0.dev15
3
+ Version: 0.1.0.dev16
4
4
  Summary: Bridging statistical cultures with some jazz
5
5
  Author: Eshin Jolly
6
6
  License-Expression: MIT
@@ -197,6 +197,7 @@ def build_mee_state(
197
197
  mee_type: str,
198
198
  *,
199
199
  units: str = "link",
200
+ weights: str = "equal",
200
201
  L_matrix: np.ndarray | None = None,
201
202
  contrast_method: str | None = None,
202
203
  n_contrast_levels: int | None = None,
@@ -219,6 +220,8 @@ def build_mee_state(
219
220
  focal_var: The primary variable being explored.
220
221
  mee_type: Type of effect ("means", "slopes", "contrasts").
221
222
  units: Scale of the estimates: "link" (linear predictor) or "data" (response).
223
+ weights: Averaging philosophy: ``"equal"`` (balanced reference grid,
224
+ emmeans-style) or ``"observed"`` (g-computation over observed data).
222
225
  L_matrix: Design matrix for delta method inference (optional).
223
226
  Shape (n_estimates, n_coef). For EMMs this is X_ref.
224
227
  contrast_method: Original contrast type for multiplicity adjustment
@@ -258,6 +261,7 @@ def build_mee_state(
258
261
  focal_var=focal_var,
259
262
  type=mee_type,
260
263
  units=units,
264
+ weights=weights,
261
265
  L_matrix=L_matrix,
262
266
  contrast_method=contrast_method,
263
267
  n_contrast_levels=n_contrast_levels,
@@ -258,7 +258,9 @@ class MeeState:
258
258
  focal_var: str = field(validator=validators.instance_of(str))
259
259
  type: str = field(validator=validators.in_(("means", "slopes", "contrasts")))
260
260
  units: str = field(default="link", validator=validators.in_(("link", "data")))
261
- weights: str = field(default="equal", validator=validators.in_(("equal", "observed")))
261
+ weights: str = field(
262
+ default="equal", validator=validators.in_(("equal", "observed"))
263
+ )
262
264
 
263
265
  # Design matrix for delta method inference (optional)
264
266
  # Shape: (n_estimates, n_coef). For EMMs, this is X_ref; for slopes, a selector row.
@@ -9,6 +9,7 @@ from bossanova.internal.maths.design.names import (
9
9
 
10
10
  __all__ = [
11
11
  "build_continuous_reference_matrix",
12
+ "build_counterfactual_design_matrices",
12
13
  "build_reference_design_matrix",
13
14
  "build_reference_row",
14
15
  ]
@@ -248,6 +249,98 @@ def build_continuous_reference_matrix(
248
249
  return X_ref
249
250
 
250
251
 
252
def build_counterfactual_design_matrices(
    X: np.ndarray,
    X_names: tuple[str, ...] | list[str],
    focal_var: str,
    levels: list[str],
) -> list[np.ndarray]:
    """Build counterfactual design matrices for g-computation.

    For each focal level, creates a modified copy of the full design matrix
    ``X`` where the focal variable's indicator columns are set to match that
    level and interaction columns involving the focal variable are recomputed
    from component values. Non-focal columns are left unchanged, preserving
    the observed covariate distribution.

    This is the core building block for ``weights="observed"``
    (g-computation / counterfactual prediction). Each returned matrix answers:
    "What would the design matrix look like if every observation were assigned
    to this focal level?"

    Args:
        X: Original design matrix, shape ``(N, p)``.
        X_names: Column names from the design matrix, in order (one per
            column of ``X``).
        focal_var: Name of the categorical focal variable.
        levels: List of levels to compute counterfactuals for. A level with
            no indicator column (e.g. the reference level) sets every focal
            indicator to 0.

    Returns:
        List of counterfactual design matrices (one per level), each shape
        ``(N, p)`` and each an independent copy of ``X``. Order matches
        ``levels``.

    Raises:
        ValueError: If ``X`` is not 2-D, if ``X_names`` does not have exactly
            one entry per column of ``X``, or if an interaction column that
            involves ``focal_var`` has a component that cannot be rebuilt.
            A component cannot be rebuilt when it is not a focal categorical
            indicator and has no standalone column in ``X``, as happens with
            ``y ~ treatment + treatment:x`` and no ``x`` column. Such a
            component used to be silently dropped from the product, which
            produced wrong counterfactual values.

    Examples:
        For a model ``y ~ treatment + x + treatment:x`` with
        ``X_names = ("Intercept", "x", "treatment[B]", "x:treatment[B]")``::

            mats = build_counterfactual_design_matrices(X, X_names, "treatment", ["ref", "B"])
            # mats[0]: treatment set to ref for all rows (treatment[B]=0, x:treatment[B]=0)
            # mats[1]: treatment set to B for all rows (treatment[B]=1, x:treatment[B]=x_i)
    """
    X_names_list = list(X_names)
    if X.ndim != 2:
        raise ValueError(f"X must be a 2-D design matrix, got shape {X.shape}")
    n_rows, n_cols = X.shape
    if n_cols != len(X_names_list):
        raise ValueError(
            f"X has {n_cols} columns but {len(X_names_list)} column names were given"
        )

    # First-occurrence index of each column name (same semantics as list.index).
    col_index: dict[str, int] = {}
    for j, name in enumerate(X_names_list):
        col_index.setdefault(name, j)

    # Pre-parse column metadata once.
    col_infos = [parse_design_column_name(name) for name in X_names_list]

    # Focal dummy columns paired with the level each one indicates.
    focal_cols: list[tuple[int, str | None]] = [
        (j, info.level)
        for j, info in enumerate(col_infos)
        if info.column_type == "categorical" and info.base_term == focal_var
    ]

    # For every interaction column involving the focal variable, build a
    # recipe once. A component is either (None, level), a counterfactual
    # focal indicator, or (source_index, None), the observed values of a
    # standalone column. The recipe keeps the original part order so the
    # floating-point product is computed in the same sequence.
    interaction_plans: list[tuple[int, list[tuple[int | None, str | None]]]] = []
    for j, info in enumerate(col_infos):
        if not info.is_interaction:
            continue
        parts = X_names_list[j].split(":")
        part_infos = [parse_design_column_name(part) for part in parts]
        if not any(pinfo.base_term == focal_var for pinfo in part_infos):
            continue

        components: list[tuple[int | None, str | None]] = []
        for part, pinfo in zip(parts, part_infos):
            if pinfo.base_term == focal_var and pinfo.column_type == "categorical":
                # Focal categorical component: replaced by the counterfactual indicator.
                components.append((None, pinfo.level))
            elif part in col_index:
                # Non-focal component (or a continuous focal one): keep observed values.
                components.append((col_index[part], None))
            else:
                # Dropping this factor silently would corrupt the column.
                raise ValueError(
                    f"Cannot build counterfactual values for interaction column "
                    f"{X_names_list[j]!r}: component {part!r} has no standalone "
                    f"column in the design matrix."
                )
        interaction_plans.append((j, components))

    result: list[np.ndarray] = []
    for level in levels:
        X_cf = X.copy()

        # Set focal dummy columns.
        for j, lvl in focal_cols:
            X_cf[:, j] = 1.0 if lvl == level else 0.0

        # Recompute interaction columns involving the focal variable. This
        # runs after the dummy assignment, so it takes precedence for any
        # column that matched both.
        for j, components in interaction_plans:
            col_product = np.ones(n_rows, dtype=np.float64)
            for src_idx, comp_level in components:
                if src_idx is None:
                    col_product *= 1.0 if comp_level == level else 0.0
                else:
                    col_product *= X[:, src_idx]
            X_cf[:, j] = col_product

        result.append(X_cf)

    return result
342
+
343
+
251
344
  # ---------------------------------------------------------------------------
252
345
  # Internal helpers
253
346
  # ---------------------------------------------------------------------------
@@ -189,6 +189,19 @@ def _dispatch_marginal_core(
189
189
  if varying not in ("exclude", "include"):
190
190
  raise ValueError(f"varying must be 'exclude' or 'include', got {varying!r}")
191
191
 
192
+ # Extract weights= kwarg with smart default
193
+ weights_raw = kwargs.pop("weights", None)
194
+ if weights_raw is not None:
195
+ weights = str(weights_raw)
196
+ if weights not in ("equal", "observed"):
197
+ raise ValueError(f"weights must be 'equal' or 'observed', got {weights!r}")
198
+ else:
199
+ # Smart default: "observed" for GLMs (non-identity link), "equal" otherwise
200
+ if spec is not None and spec.link != "identity":
201
+ weights = "observed"
202
+ else:
203
+ weights = "equal"
204
+
192
205
  # Extract inverse_transforms= kwarg (default True)
193
206
  inverse_transforms = bool(kwargs.pop("inverse_transforms", True))
194
207
 
@@ -253,6 +266,7 @@ def _dispatch_marginal_core(
253
266
  data=data,
254
267
  spec=spec,
255
268
  units=units,
269
+ weights=weights,
256
270
  resolved=resolved,
257
271
  )
258
272
 
@@ -267,6 +281,7 @@ def _dispatch_marginal_core(
267
281
  data,
268
282
  spec=spec,
269
283
  units=units,
284
+ weights=weights,
270
285
  resolved=resolved,
271
286
  focal_at_values=parsed.focal_at_values,
272
287
  contrast_ref=parsed.contrast_ref,
@@ -290,6 +305,7 @@ def _dispatch_marginal_core(
290
305
  set_categoricals=resolved.set_categoricals or None,
291
306
  spec=spec,
292
307
  units=units,
308
+ weights=weights,
293
309
  )
294
310
  # Continuous at-values (e.g. "Days@[0, 3, 6, 9]"): EMMs at
295
311
  # specific values. Forward-transform the at-values through the
@@ -341,6 +357,7 @@ def _dispatch_marginal_core(
341
357
  set_categoricals=resolved.set_categoricals or None,
342
358
  spec=spec,
343
359
  units=units,
360
+ weights=weights,
344
361
  )
345
362
 
346
363
  # Continuous focal variable → slopes
@@ -390,6 +407,7 @@ def _dispatch_marginal_core(
390
407
  spec=spec,
391
408
  formula_spec=fspec,
392
409
  units=units,
410
+ weights=weights,
393
411
  )
394
412
 
395
413
 
@@ -549,6 +567,7 @@ def _compute_emm_categorical(
549
567
  set_categoricals: dict[str, str] | None = None,
550
568
  spec: ModelSpec | None = None,
551
569
  units: str = "link",
570
+ weights: str = "equal",
552
571
  ) -> MeeState:
553
572
  """Compute EMMs for a categorical focal variable.
554
573
 
@@ -563,6 +582,7 @@ def _compute_emm_categorical(
563
582
  specific levels (e.g. ``{"Ethnicity": "Asian"}``).
564
583
  spec: ModelSpec with link info (for units="data").
565
584
  units: Scale of estimates.
585
+ weights: Averaging philosophy: ``"equal"`` or ``"observed"``.
566
586
 
567
587
  Returns:
568
588
  MeeState with grid of levels and their estimated means.
@@ -577,6 +597,7 @@ def _compute_emm_categorical(
577
597
  set_categoricals=set_categoricals,
578
598
  spec=spec,
579
599
  units=units,
600
+ weights=weights,
580
601
  )
581
602
 
582
603
 
@@ -779,6 +800,7 @@ def _compute_marginal_slope(
779
800
  spec: ModelSpec | None = None,
780
801
  formula_spec: FormulaSpec | None = None,
781
802
  units: str = "link",
803
+ weights: str = "equal",
782
804
  ) -> MeeState:
783
805
  """Compute marginal slope for a continuous focal variable.
784
806
 
@@ -797,6 +819,7 @@ def _compute_marginal_slope(
797
819
  spec: ModelSpec with family/link info.
798
820
  formula_spec: FormulaSpec with learned encoding (for finite-diff).
799
821
  units: Scale of estimates: ``"link"`` or ``"data"``.
822
+ weights: Averaging philosophy: ``"equal"`` or ``"observed"``.
800
823
 
801
824
  Returns:
802
825
  MeeState with marginal slope estimate.
@@ -825,6 +848,7 @@ def _compute_marginal_slope(
825
848
  formula_spec=formula_spec,
826
849
  data=data,
827
850
  units=units,
851
+ weights=weights,
828
852
  )
829
853
 
830
854
  # Fast path: coefficient extraction — use resolved name
@@ -1021,6 +1045,7 @@ def _compute_contrasts(
1021
1045
  *,
1022
1046
  spec: ModelSpec | None = None,
1023
1047
  units: str = "link",
1048
+ weights: str = "equal",
1024
1049
  resolved: ResolvedConditions | None = None,
1025
1050
  focal_at_values: tuple[float | str, ...] | None = None,
1026
1051
  contrast_ref: str | None = None,
@@ -1030,7 +1055,7 @@ def _compute_contrasts(
1030
1055
  First computes EMMs, then applies the requested contrast matrix.
1031
1056
  When ``resolved`` contains grid conditions, computes crossed EMMs and
1032
1057
  applies grouped contrasts. When ``focal_at_values`` is provided
1033
- (e.g. from ``pairwise(cyl[4, 8])``), contrasts are computed only
1058
+ (e.g. from ``pairwise(cyl@[4, 8])``), contrasts are computed only
1034
1059
  over the requested subset of levels.
1035
1060
 
1036
1061
  Args:
@@ -1043,6 +1068,7 @@ def _compute_contrasts(
1043
1068
  data: Model data for level extraction.
1044
1069
  spec: ModelSpec with link info (for units="data").
1045
1070
  units: Scale of estimates.
1071
+ weights: Averaging philosophy: ``"equal"`` or ``"observed"``.
1046
1072
  resolved: Resolved conditions for conditioning.
1047
1073
  focal_at_values: Optional subset of levels from at-spec syntax
1048
1074
  (e.g. ``pairwise(cyl@[4, 8])``).
@@ -1113,6 +1139,7 @@ def _compute_contrasts(
1113
1139
  set_categoricals=set_cats,
1114
1140
  spec=spec,
1115
1141
  units=units,
1142
+ weights=weights,
1116
1143
  )
1117
1144
  if contrast_ref is not None:
1118
1145
  ref_idx = _resolve_ref_idx(emm_state.grid)
@@ -1130,6 +1157,7 @@ def _compute_bracket_contrasts(
1130
1157
  data: pl.DataFrame,
1131
1158
  spec: ModelSpec | None = None,
1132
1159
  units: str = "link",
1160
+ weights: str = "equal",
1133
1161
  resolved: ResolvedConditions | None = None,
1134
1162
  ) -> MeeState:
1135
1163
  """Compute bracket contrast expression for a categorical focal variable.
@@ -1146,6 +1174,7 @@ def _compute_bracket_contrasts(
1146
1174
  data: Model data DataFrame.
1147
1175
  spec: ModelSpec with link info (for units="data").
1148
1176
  units: Scale of estimates.
1177
+ weights: Averaging philosophy: ``"equal"`` or ``"observed"``.
1149
1178
  resolved: Resolved conditions for conditioning.
1150
1179
 
1151
1180
  Returns:
@@ -1179,5 +1208,6 @@ def _compute_bracket_contrasts(
1179
1208
  set_categoricals=set_cats,
1180
1209
  spec=spec,
1181
1210
  units=units,
1211
+ weights=weights,
1182
1212
  )
1183
1213
  return apply_bracket_contrasts(emm_state, contrast_expr)
@@ -16,7 +16,10 @@ import polars as pl
16
16
 
17
17
  from bossanova.internal.containers.builders.state import build_mee_state
18
18
  from bossanova.internal.containers.structs.state import MeeState
19
- from bossanova.internal.maths.design.reference import build_reference_design_matrix
19
+ from bossanova.internal.maths.design.reference import (
20
+ build_counterfactual_design_matrices,
21
+ build_reference_design_matrix,
22
+ )
20
23
  from bossanova.internal.operations.common.factors import get_factor_levels
21
24
  from bossanova.internal.operations.marginal.validation import validate_focal_var
22
25
 
@@ -41,18 +44,21 @@ def compute_emm(
41
44
  set_categoricals: dict[str, str] | None = None,
42
45
  spec: object | None = None,
43
46
  units: str = "link",
47
+ weights: str = "equal",
44
48
  ) -> MeeState:
45
49
  """Compute estimated marginal means for a categorical focal variable.
46
50
 
47
- Computes predictions at each level of the focal variable, with other
48
- covariates set to their reference values (means for continuous, 0 for
49
- non-focal categorical).
51
+ Supports two averaging philosophies via the ``weights`` parameter:
50
52
 
51
- This matches the emmeans package behavior: EMMs are model predictions
52
- at a reference grid where:
53
- - The focal variable varies across its levels
54
- - Other continuous covariates are set to their overall means
55
- - Other categorical covariates are set to the reference level (0 in dummy coding)
53
+ - ``"equal"`` (default): Balanced reference grid (emmeans-style).
54
+ Predictions at a grid where covariates are at their means.
55
+ - ``"observed"``: G-computation / counterfactual prediction. For each
56
+ focal level, sets every observation to that level and averages the
57
+ resulting predictions. Preserves the observed covariate distribution.
58
+
59
+ For linear models (identity link), both approaches give identical results.
60
+ For GLMs (non-identity link), they diverge because
61
+ ``mean(g⁻¹(Xᵢβ)) ≠ g⁻¹(mean(Xᵢ) · β)``.
56
62
 
57
63
  Args:
58
64
  bundle: DataBundle with model data and metadata. Used to extract:
@@ -70,6 +76,11 @@ def compute_emm(
70
76
  set_categoricals: Optional dict mapping non-focal categorical variable
71
77
  names to specific levels to pin them at (instead of marginalizing
72
78
  at column means). E.g. ``{"Ethnicity": "Asian"}``.
79
+ spec: ModelSpec with link/family info (needed for units="data").
80
+ units: Scale of estimates: ``"link"`` or ``"data"``.
81
+ weights: Averaging philosophy: ``"equal"`` for balanced reference
82
+ grid (emmeans-style), ``"observed"`` for g-computation over
83
+ observed data.
73
84
 
74
85
  Returns:
75
86
  MeeState with grid of levels and their estimated means.
@@ -78,17 +89,20 @@ def compute_emm(
78
89
  ValueError: If focal_var not found in bundle.factor_levels.
79
90
 
80
91
  Examples:
81
- Compute EMMs for treatment levels::
92
+ Compute EMMs with default (reference grid) averaging::
82
93
 
83
- from bossanova.internal.operations.marginal import compute_emm
84
94
  mee = compute_emm(bundle, fit, "treatment", "treatment")
85
- # mee.estimate: array([2.1, 3.4, 2.8]) for each level
86
- # mee.grid: DataFrame with "treatment" column
95
+
96
+ G-computation averaging for a logistic regression::
97
+
98
+ mee = compute_emm(bundle, fit, "treatment", "treatment",
99
+ spec=spec, units="data", weights="observed")
87
100
 
88
101
  Note:
89
- The X_ref matrix is constructed internally using bundle.X_names
90
- to identify which columns correspond to the focal variable dummies.
91
- This works for standard treatment/dummy coding.
102
+ For ``weights="observed"`` with ``units="data"`` on a non-identity
103
+ link, confidence intervals use the response-scale delta method
104
+ (symmetric CIs) rather than the link-scale back-transformation
105
+ (asymmetric CIs) used by ``weights="equal"``.
92
106
  """
93
107
  validate_focal_var(bundle, focal_var)
94
108
 
@@ -96,10 +110,8 @@ def compute_emm(
96
110
  if levels is None:
97
111
  levels = get_factor_levels(bundle, focal_var)
98
112
 
99
- # Build reference X matrix using shared utility
100
- X_means = bundle.X.mean(axis=0).copy()
101
-
102
113
  # Apply at-value overrides for conditioning
114
+ X_means = bundle.X.mean(axis=0).copy()
103
115
  if at_overrides:
104
116
  X_names_list = list(bundle.X_names)
105
117
  for var_name, value in at_overrides.items():
@@ -117,6 +129,29 @@ def compute_emm(
117
129
  stacklevel=4,
118
130
  )
119
131
 
132
+ # G-computation path: weights="observed" with non-identity link on data scale.
133
+ # For link scale or identity link, the two approaches are mathematically
134
+ # equivalent, so we fall through to the reference grid path for efficiency.
135
+ use_gcomp = (
136
+ weights == "observed"
137
+ and units == "data"
138
+ and spec is not None
139
+ and spec.link != "identity"
140
+ )
141
+
142
+ if use_gcomp:
143
+ return _compute_emm_gcomp(
144
+ bundle=bundle,
145
+ fit=fit,
146
+ focal_var=focal_var,
147
+ explore_formula=explore_formula,
148
+ levels=levels,
149
+ spec=spec,
150
+ at_overrides=at_overrides,
151
+ set_categoricals=set_categoricals,
152
+ )
153
+
154
+ # Reference grid path (weights="equal" or equivalent cases)
120
155
  X_ref = build_reference_design_matrix(
121
156
  X_names=bundle.X_names,
122
157
  focal_var=focal_var,
@@ -149,9 +184,143 @@ def compute_emm(
149
184
  L_matrix_link = X_ref
150
185
 
151
186
  # Create grid DataFrame with conditioning columns first
187
+ grid = _build_emm_grid(focal_var, levels, at_overrides, set_categoricals)
188
+
189
+ return build_mee_state(
190
+ grid=grid,
191
+ estimate=estimates,
192
+ explore_formula=explore_formula,
193
+ focal_var=focal_var,
194
+ mee_type="means",
195
+ units=units,
196
+ weights="equal",
197
+ L_matrix=L_matrix,
198
+ link=link_name,
199
+ L_matrix_link=L_matrix_link,
200
+ )
201
+
202
+
203
+ def _compute_emm_gcomp(
204
+ bundle: "DataBundle",
205
+ fit: "FitState",
206
+ focal_var: str,
207
+ explore_formula: str,
208
+ levels: list[str],
209
+ spec: object,
210
+ *,
211
+ at_overrides: dict[str, float] | None = None,
212
+ set_categoricals: dict[str, str] | None = None,
213
+ ) -> MeeState:
214
+ """G-computation EMMs: average counterfactual predictions over observed data.
215
+
216
+ For each focal level k:
217
+ 1. Build counterfactual X_cf_k (all N rows, focal set to level k)
218
+ 2. η_k = X_cf_k @ β (N linear predictors)
219
+ 3. pred_k = g⁻¹(η_k) (N response-scale predictions)
220
+ 4. emm_k = mean(pred_k)
221
+ 5. L_k = mean(dμ/dη(η_k)[:, None] * X_cf_k, axis=0) (delta method)
222
+
223
+ Args:
224
+ bundle: DataBundle with model data and design matrix.
225
+ fit: FitState with fitted coefficients.
226
+ focal_var: Name of the categorical focal variable.
227
+ explore_formula: Explore formula string for metadata.
228
+ levels: Focal variable levels to compute EMMs for.
229
+ spec: ModelSpec with link function info.
230
+ at_overrides: Optional covariate overrides (applied before
231
+ counterfactual construction).
232
+ set_categoricals: Optional non-focal categorical pins.
233
+
234
+ Returns:
235
+ MeeState with g-computation estimates and response-scale L_matrix.
236
+ """
237
+ from bossanova.internal.maths.family.links import (
238
+ apply_link_inverse,
239
+ apply_link_inverse_deriv,
240
+ )
241
+
242
+ X = bundle.X
243
+
244
+ # Apply at_overrides: replace column values in X before counterfactual
245
+ if at_overrides:
246
+ X = X.copy()
247
+ X_names_list = list(bundle.X_names)
248
+ for var_name, value in at_overrides.items():
249
+ if var_name in X_names_list:
250
+ idx = X_names_list.index(var_name)
251
+ X[:, idx] = value
252
+
253
+ # Apply set_categoricals: pin non-focal categoricals in X
254
+ if set_categoricals:
255
+ from bossanova.internal.maths.design.names import parse_design_column_name
256
+
257
+ X = X if at_overrides else X.copy() # ensure we have a copy
258
+ for j, name in enumerate(bundle.X_names):
259
+ info = parse_design_column_name(name)
260
+ if info.column_type == "categorical" and info.base_term in set_categoricals:
261
+ X[:, j] = 1.0 if info.level == set_categoricals[info.base_term] else 0.0
262
+
263
+ # Build counterfactual design matrices
264
+ cf_matrices = build_counterfactual_design_matrices(
265
+ X=X,
266
+ X_names=bundle.X_names,
267
+ focal_var=focal_var,
268
+ levels=levels,
269
+ )
270
+
271
+ # Compute g-computation estimates and L_matrix for each level
272
+ n_levels = len(levels)
273
+ p = X.shape[1]
274
+ estimates = np.empty(n_levels, dtype=np.float64)
275
+ L_matrix = np.empty((n_levels, p), dtype=np.float64)
276
+
277
+ for k, X_cf in enumerate(cf_matrices):
278
+ eta_k = X_cf @ fit.coef # (N,)
279
+ pred_k = np.asarray(apply_link_inverse(spec.link, eta_k)) # (N,)
280
+ estimates[k] = pred_k.mean()
281
+
282
+ # Delta method: L_k = mean(dμ/dη(η_k)[:, None] * X_cf_k, axis=0)
283
+ d_k = np.asarray(apply_link_inverse_deriv(spec.link, eta_k)) # (N,)
284
+ L_matrix[k] = (d_k[:, np.newaxis] * X_cf).mean(axis=0)
285
+
286
+ # Build grid DataFrame
287
+ grid = _build_emm_grid(focal_var, levels, at_overrides, set_categoricals)
288
+
289
+ # G-computation: no link/L_matrix_link — inference uses response-scale
290
+ # delta method directly (symmetric CIs on response scale).
291
+ return build_mee_state(
292
+ grid=grid,
293
+ estimate=estimates,
294
+ explore_formula=explore_formula,
295
+ focal_var=focal_var,
296
+ mee_type="means",
297
+ units="data",
298
+ weights="observed",
299
+ L_matrix=L_matrix,
300
+ link=None,
301
+ L_matrix_link=None,
302
+ )
303
+
304
+
305
+ def _build_emm_grid(
306
+ focal_var: str,
307
+ levels: list[str],
308
+ at_overrides: dict[str, float] | None,
309
+ set_categoricals: dict[str, str] | None,
310
+ ) -> pl.DataFrame:
311
+ """Build the grid DataFrame for EMM results.
312
+
313
+ Args:
314
+ focal_var: Name of the focal variable.
315
+ levels: Focal variable levels.
316
+ at_overrides: Optional covariate overrides to include as columns.
317
+ set_categoricals: Optional categorical pins to include as columns.
318
+
319
+ Returns:
320
+ Polars DataFrame with condition columns first, then focal column.
321
+ """
152
322
  grid = pl.DataFrame({focal_var: levels})
153
323
 
154
- # Add set_categoricals as condition columns
155
324
  if set_categoricals:
156
325
  for cond_var, cond_level in set_categoricals.items():
157
326
  grid = grid.with_columns(pl.lit(cond_level).alias(cond_var))
@@ -165,17 +334,7 @@ def compute_emm(
165
334
  if cond_cols:
166
335
  grid = grid.select(cond_cols + [c for c in grid.columns if c not in cond_cols])
167
336
 
168
- return build_mee_state(
169
- grid=grid,
170
- estimate=estimates,
171
- explore_formula=explore_formula,
172
- focal_var=focal_var,
173
- mee_type="means",
174
- units=units,
175
- L_matrix=L_matrix,
176
- link=link_name,
177
- L_matrix_link=L_matrix_link,
178
- )
337
+ return grid
179
338
 
180
339
 
181
340
  def compute_emm_with_xref(
@@ -251,22 +251,24 @@ def compute_slopes_finite_diff(
251
251
  formula_spec: FormulaSpec,
252
252
  data: pl.DataFrame,
253
253
  units: str = "link",
254
+ weights: str = "equal",
254
255
  delta_frac: float = 0.001,
255
256
  ) -> MeeState:
256
257
  """Compute marginal slopes via centered finite differences (AME).
257
258
 
258
- Builds a reference grid (Cartesian product of categorical levels,
259
- continuous covariates at means), perturbs the focal variable by
260
- ``±delta/2``, constructs design matrices via ``evaluate_newdata``,
261
- and differences them to obtain the Average Marginal Effect.
259
+ Supports two averaging philosophies via the ``weights`` parameter:
262
260
 
263
- For ``units="data"`` on GLMs the response-scale derivative
264
- ``dμ/dη`` is applied row-wise before averaging, giving a true AME
265
- (not MEM).
261
+ - ``"equal"``: Average over a balanced reference grid (Cartesian product
262
+ of categorical levels, continuous covariates at means). Matches R's
263
+ ``emmeans::emtrends``.
264
+ - ``"observed"``: Average over actual data rows. Gives a true Average
265
+ Marginal Effect (AME) that preserves the observed covariate distribution.
266
266
 
267
- Algorithm (matches R ``emmeans::emtrends`` conventions):
267
+ For linear models (identity link), both approaches give identical results.
268
+
269
+ Algorithm:
268
270
  1. ``delta = delta_frac × range(focal_var)``
269
- 2. Build reference grid (factor Cartesian product, numerics at means)
271
+ 2. Build evaluation grid (balanced or observed data)
270
272
  3. Perturb: ``grid_plus = grid[focal + delta/2]``,
271
273
  ``grid_minus = grid[focal − delta/2]``
272
274
  4. ``X_plus, X_minus = evaluate_newdata(formula_spec, grid_*)``
@@ -283,6 +285,8 @@ def compute_slopes_finite_diff(
283
285
  formula_spec: FormulaSpec with learned encoding for ``evaluate_newdata``.
284
286
  data: Raw model data (for computing ranges and covariate means).
285
287
  units: ``"link"`` (linear predictor scale) or ``"data"`` (response).
288
+ weights: ``"equal"`` for balanced reference grid, ``"observed"``
289
+ for actual data rows.
286
290
  delta_frac: Fraction of the focal variable's range used as the
287
291
  finite-difference step size. Default ``0.001`` matches R's
288
292
  ``emmeans``.
@@ -309,8 +313,15 @@ def compute_slopes_finite_diff(
309
313
  )
310
314
  delta = delta_frac * var_range
311
315
 
312
- # --- reference grid ---------------------------------------------------
313
- grid = _build_slopes_data_grid(bundle, data)
316
+ # --- evaluation grid --------------------------------------------------
317
+ # weights="observed": use actual data rows (true AME)
318
+ # weights="equal": use balanced reference grid (emtrends-style)
319
+ if weights == "observed":
320
+ # Use actual data rows, dropping the response column
321
+ grid_cols = [c for c in data.columns if c != bundle.y_name]
322
+ grid = data.select(grid_cols)
323
+ else:
324
+ grid = _build_slopes_data_grid(bundle, data)
314
325
 
315
326
  # --- perturbed grids --------------------------------------------------
316
327
  grid_plus = grid.with_columns(
@@ -351,6 +362,7 @@ def compute_slopes_finite_diff(
351
362
  focal_var=focal_var,
352
363
  mee_type="slopes",
353
364
  units=units,
365
+ weights=weights,
354
366
  L_matrix=L_avg,
355
367
  )
356
368
 
@@ -1247,6 +1247,10 @@ class model:
1247
1247
  use transformed-scale names/values directly.
1248
1248
  - ``units`` (str): ``"link"`` or ``"data"``.
1249
1249
  - ``varying`` (str): ``"exclude"`` or ``"include"``.
1250
+ - ``weights`` (str): ``"equal"`` for emmeans-style balanced
1251
+ reference grid, ``"observed"`` for g-computation over
1252
+ observed data. Default is smart: ``"observed"`` for GLMs
1253
+ (non-identity link), ``"equal"`` for linear models.
1250
1254
 
1251
1255
  Returns:
1252
1256
  self: For method chaining. Results in ``.effects``.