pysp-learn 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pysp/__init__.py +1 -0
- pysp/arithmetic.py +103 -0
- pysp/data/__init__.py +20 -0
- pysp/data/dataframe.py +75 -0
- pysp/data/graph_data.py +257 -0
- pysp/data/rdd_sampler.py +46 -0
- pysp/doe/__init__.py +91 -0
- pysp/doe/bayesopt.py +333 -0
- pysp/doe/constrained.py +210 -0
- pysp/doe/designs.py +182 -0
- pysp/doe/multiobjective.py +128 -0
- pysp/doe/optimal.py +236 -0
- pysp/doe/optimizer.py +170 -0
- pysp/engines/__init__.py +119 -0
- pysp/engines/base.py +87 -0
- pysp/engines/numpy_engine.py +107 -0
- pysp/engines/precision.py +170 -0
- pysp/engines/symbolic_engine.py +529 -0
- pysp/engines/symbolic_export.py +243 -0
- pysp/engines/torch_engine.py +267 -0
- pysp/infer/__init__.py +560 -0
- pysp/infer/backends.py +127 -0
- pysp/infer/diagnostics.py +96 -0
- pysp/models/__init__.py +78 -0
- pysp/models/dependence.py +193 -0
- pysp/models/dpm.py +326 -0
- pysp/models/gaussian_process.py +238 -0
- pysp/models/grammar.py +200 -0
- pysp/models/knowledge_graph.py +198 -0
- pysp/models/neural.py +320 -0
- pysp/models/pomdp.py +318 -0
- pysp/models/random_graph.py +314 -0
- pysp/planner.py +1847 -0
- pysp/ppl/__init__.py +837 -0
- pysp/ppl/autograd.py +537 -0
- pysp/ppl/benchmark.py +152 -0
- pysp/ppl/benchmark_vs.py +254 -0
- pysp/ppl/core.py +1636 -0
- pysp/ppl/diagnostics.py +194 -0
- pysp/ppl/dynamics.py +198 -0
- pysp/ppl/inference.py +1705 -0
- pysp/ppl/pde.py +341 -0
- pysp/ppl/regression.py +529 -0
- pysp/ppl/statespace.py +96 -0
- pysp/ppl/training_data.py +138 -0
- pysp/ppl/vmp.py +490 -0
- pysp/stats/__init__.py +1879 -0
- pysp/stats/bayes/__init__.py +1 -0
- pysp/stats/bayes/catdirichlet.py +173 -0
- pysp/stats/bayes/dirichlet.py +985 -0
- pysp/stats/bayes/dpm.py +743 -0
- pysp/stats/bayes/hdpm.py +666 -0
- pysp/stats/bayes/mvngamma.py +211 -0
- pysp/stats/bayes/normgamma.py +195 -0
- pysp/stats/bayes/normwishart.py +232 -0
- pysp/stats/bayes/pitman_yor.py +372 -0
- pysp/stats/bayes/symdirichlet.py +141 -0
- pysp/stats/combinator/__init__.py +1 -0
- pysp/stats/combinator/censored.py +286 -0
- pysp/stats/combinator/composite.py +1075 -0
- pysp/stats/combinator/conditional.py +1487 -0
- pysp/stats/combinator/exponential_tilt.py +547 -0
- pysp/stats/combinator/finite_stochastic_transform.py +314 -0
- pysp/stats/combinator/ignored.py +287 -0
- pysp/stats/combinator/null_dist.py +457 -0
- pysp/stats/combinator/optional.py +773 -0
- pysp/stats/combinator/record.py +586 -0
- pysp/stats/combinator/select.py +794 -0
- pysp/stats/combinator/sequence.py +1193 -0
- pysp/stats/combinator/transform.py +558 -0
- pysp/stats/combinator/truncated.py +287 -0
- pysp/stats/combinator/weighted.py +451 -0
- pysp/stats/compute/__init__.py +1 -0
- pysp/stats/compute/backend.py +62 -0
- pysp/stats/compute/capabilities.py +110 -0
- pysp/stats/compute/declarations.py +1454 -0
- pysp/stats/compute/encoded.py +93 -0
- pysp/stats/compute/fused_kernels.py +1364 -0
- pysp/stats/compute/gradient.py +622 -0
- pysp/stats/compute/kernel.py +634 -0
- pysp/stats/compute/pdist.py +1018 -0
- pysp/stats/compute/stacked.py +650 -0
- pysp/stats/compute/torch_mixture.py +269 -0
- pysp/stats/exp_family.py +421 -0
- pysp/stats/graph/__init__.py +1 -0
- pysp/stats/graph/chow_liu_tree.py +713 -0
- pysp/stats/graph/erdos_renyi_graph.py +380 -0
- pysp/stats/graph/grammar.py +682 -0
- pysp/stats/graph/icltree.py +729 -0
- pysp/stats/graph/int_markovchain.py +1377 -0
- pysp/stats/graph/mallows.py +352 -0
- pysp/stats/graph/markov_chain.py +2006 -0
- pysp/stats/graph/markov_transform.py +896 -0
- pysp/stats/graph/matching.py +330 -0
- pysp/stats/graph/plackett_luce.py +482 -0
- pysp/stats/graph/rdpg.py +274 -0
- pysp/stats/graph/spanning_tree.py +362 -0
- pysp/stats/graph/sparse_markov_transform.py +980 -0
- pysp/stats/graph/spearman_rho.py +650 -0
- pysp/stats/graph/stochastic_block_graph.py +596 -0
- pysp/stats/latent/__init__.py +1 -0
- pysp/stats/latent/_hmm_numba_kernels.py +288 -0
- pysp/stats/latent/dirac_length.py +1023 -0
- pysp/stats/latent/heterogeneous_mixture.py +1149 -0
- pysp/stats/latent/heterogeneous_pcfg.py +1245 -0
- pysp/stats/latent/hidden_association.py +950 -0
- pysp/stats/latent/hidden_markov.py +2861 -0
- pysp/stats/latent/hidden_markov_ind_pi.py +1878 -0
- pysp/stats/latent/hmixture.py +1251 -0
- pysp/stats/latent/ibp.py +633 -0
- pysp/stats/latent/int_hidden_association.py +1522 -0
- pysp/stats/latent/int_plsi.py +1354 -0
- pysp/stats/latent/jmixture.py +1103 -0
- pysp/stats/latent/lda.py +1596 -0
- pysp/stats/latent/llda.py +1773 -0
- pysp/stats/latent/look_back_hmm.py +1232 -0
- pysp/stats/latent/mixture.py +1594 -0
- pysp/stats/latent/mvnmixture.py +840 -0
- pysp/stats/latent/ppca.py +318 -0
- pysp/stats/latent/quantized_hmm.py +924 -0
- pysp/stats/latent/segmental_hmm.py +706 -0
- pysp/stats/latent/ss_mixture.py +1070 -0
- pysp/stats/latent/tree_hmm.py +2536 -0
- pysp/stats/leaf/__init__.py +1 -0
- pysp/stats/leaf/bernoulli.py +398 -0
- pysp/stats/leaf/beta.py +419 -0
- pysp/stats/leaf/binomial.py +1073 -0
- pysp/stats/leaf/birth_death.py +332 -0
- pysp/stats/leaf/cat_multinomial.py +1023 -0
- pysp/stats/leaf/categorical.py +994 -0
- pysp/stats/leaf/exgaussian.py +328 -0
- pysp/stats/leaf/exponential.py +646 -0
- pysp/stats/leaf/gamma.py +745 -0
- pysp/stats/leaf/gaussian.py +774 -0
- pysp/stats/leaf/geometric.py +747 -0
- pysp/stats/leaf/gumbel.py +302 -0
- pysp/stats/leaf/half_normal.py +360 -0
- pysp/stats/leaf/inhomogeneous_poisson.py +303 -0
- pysp/stats/leaf/int_multinomial.py +1132 -0
- pysp/stats/leaf/int_range.py +1046 -0
- pysp/stats/leaf/int_spike.py +797 -0
- pysp/stats/leaf/inverse_gamma.py +415 -0
- pysp/stats/leaf/inverse_gaussian.py +437 -0
- pysp/stats/leaf/laplace.py +296 -0
- pysp/stats/leaf/log_gaussian.py +767 -0
- pysp/stats/leaf/logistic.py +273 -0
- pysp/stats/leaf/logseries.py +351 -0
- pysp/stats/leaf/negative_binomial.py +389 -0
- pysp/stats/leaf/pareto.py +356 -0
- pysp/stats/leaf/point_mass.py +251 -0
- pysp/stats/leaf/poisson.py +771 -0
- pysp/stats/leaf/rayleigh.py +296 -0
- pysp/stats/leaf/skellam.py +266 -0
- pysp/stats/leaf/student_t.py +299 -0
- pysp/stats/leaf/tweedie.py +294 -0
- pysp/stats/leaf/uniform.py +292 -0
- pysp/stats/leaf/von_mises.py +410 -0
- pysp/stats/leaf/weibull.py +355 -0
- pysp/stats/multivariate/__init__.py +1 -0
- pysp/stats/multivariate/dmvn.py +837 -0
- pysp/stats/multivariate/mvn.py +907 -0
- pysp/stats/multivariate/mvt.py +418 -0
- pysp/stats/multivariate/vmf.py +760 -0
- pysp/stats/sets/__init__.py +1 -0
- pysp/stats/sets/int_edit_setdist.py +997 -0
- pysp/stats/sets/int_edit_stepsetdist.py +987 -0
- pysp/stats/sets/int_setdist.py +540 -0
- pysp/stats/sets/setdist.py +900 -0
- pysp/tests/api_naming_aliases_test.py +227 -0
- pysp/tests/auto_precision_test.py +88 -0
- pysp/tests/automatic_gof_test.py +41 -0
- pysp/tests/automatic_lognormal_test.py +64 -0
- pysp/tests/automatic_mixture_test.py +37 -0
- pysp/tests/automatic_model_weights_test.py +49 -0
- pysp/tests/automatic_scientific_test.py +421 -0
- pysp/tests/automatic_studentt_test.py +39 -0
- pysp/tests/automatic_test.py +212 -0
- pysp/tests/backend_scoring_test.py +1828 -0
- pysp/tests/base_dist_test.py +445 -0
- pysp/tests/bayes_streaming_test.py +387 -0
- pysp/tests/bayes_test.py +879 -0
- pysp/tests/birth_death_test.py +57 -0
- pysp/tests/categorical_expfamily_test.py +86 -0
- pysp/tests/categorical_test.py +52 -0
- pysp/tests/censored_test.py +100 -0
- pysp/tests/chow_liu_tree_test.py +209 -0
- pysp/tests/compute_kernel_test.py +767 -0
- pysp/tests/compute_metadata_test.py +1830 -0
- pysp/tests/conftest.py +128 -0
- pysp/tests/continuous_cdf_test.py +147 -0
- pysp/tests/coupled_multiset_enum_test.py +224 -0
- pysp/tests/dask_encoded_data_test.py +174 -0
- pysp/tests/dataframe_adapter_test.py +169 -0
- pysp/tests/density_rank_test.py +302 -0
- pysp/tests/dirac_length_engine_test.py +40 -0
- pysp/tests/distribution_additions_test.py +418 -0
- pysp/tests/doe_bayesopt_test.py +201 -0
- pysp/tests/doe_constrained_test.py +93 -0
- pysp/tests/doe_designs_test.py +134 -0
- pysp/tests/doe_multiobjective_test.py +66 -0
- pysp/tests/doe_optimal_test.py +132 -0
- pysp/tests/doe_optimizer_test.py +108 -0
- pysp/tests/em_nonfinite_guard_test.py +109 -0
- pysp/tests/em_strategies_test.py +331 -0
- pysp/tests/encoded_data_backend_registry_test.py +66 -0
- pysp/tests/engine_accumulate_parity_test.py +90 -0
- pysp/tests/engine_test.py +205 -0
- pysp/tests/enumeration_test.py +591 -0
- pysp/tests/enumerator_coverage_test.py +173 -0
- pysp/tests/estimator_stability_test.py +104 -0
- pysp/tests/exgaussian_test.py +112 -0
- pysp/tests/exp_family_fisher_test.py +53 -0
- pysp/tests/exp_family_test.py +236 -0
- pysp/tests/exponential_tilt_test.py +177 -0
- pysp/tests/finite_stochastic_transform_test.py +118 -0
- pysp/tests/fisher_view_test.py +728 -0
- pysp/tests/fused_em_association_test.py +209 -0
- pysp/tests/fused_em_hmm_family_test.py +222 -0
- pysp/tests/fused_em_mixtures_test.py +188 -0
- pysp/tests/fused_em_test.py +176 -0
- pysp/tests/fused_em_variational_test.py +164 -0
- pysp/tests/gaussian_process_matern_test.py +80 -0
- pysp/tests/gaussian_process_monotone_test.py +65 -0
- pysp/tests/generated_kernel_parity_test.py +160 -0
- pysp/tests/gradient_fit_test.py +612 -0
- pysp/tests/graph_distribution_test.py +218 -0
- pysp/tests/graph_engine_test.py +51 -0
- pysp/tests/gumbel_test.py +61 -0
- pysp/tests/half_normal_test.py +67 -0
- pysp/tests/heterogeneous_pcfg_test.py +240 -0
- pysp/tests/hidden_association_engine_test.py +80 -0
- pysp/tests/hidden_association_keys_test.py +93 -0
- pysp/tests/hmixture_engine_test.py +57 -0
- pysp/tests/hmm_engine_test.py +193 -0
- pysp/tests/hmm_keys_test.py +36 -0
- pysp/tests/hmm_numba_parity_test.py +63 -0
- pysp/tests/hmm_sampler_batching_test.py +217 -0
- pysp/tests/hvis_test.py +1320 -0
- pysp/tests/ibp_test.py +117 -0
- pysp/tests/ind_pi_engine_test.py +85 -0
- pysp/tests/infer_backends_test.py +177 -0
- pysp/tests/infer_facade_test.py +186 -0
- pysp/tests/infer_parallel_chains_test.py +52 -0
- pysp/tests/inhomogeneous_poisson_test.py +57 -0
- pysp/tests/int_hidden_association_engine_test.py +63 -0
- pysp/tests/int_hidden_association_test.py +117 -0
- pysp/tests/int_plsi_engine_test.py +47 -0
- pysp/tests/integer_categorical_expfamily_test.py +71 -0
- pysp/tests/inverse_gamma_test.py +76 -0
- pysp/tests/jmixture_engine_test.py +56 -0
- pysp/tests/kernels_ext_test.py +375 -0
- pysp/tests/kernels_test.py +258 -0
- pysp/tests/key_validation_test.py +59 -0
- pysp/tests/lda_engine_test.py +64 -0
- pysp/tests/lda_len_test.py +145 -0
- pysp/tests/leaf_engine_test.py +80 -0
- pysp/tests/lightning_encoded_data_test.py +55 -0
- pysp/tests/llda_alpha_test.py +222 -0
- pysp/tests/llda_engine_test.py +90 -0
- pysp/tests/local_parallel_chunks_test.py +79 -0
- pysp/tests/logseries_test.py +72 -0
- pysp/tests/lookback_hmm_engine_test.py +106 -0
- pysp/tests/lookback_lag0_test.py +280 -0
- pysp/tests/mallows_test.py +82 -0
- pysp/tests/markov_transform_engine_test.py +85 -0
- pysp/tests/matching_test.py +70 -0
- pysp/tests/mcmc_autograd_test.py +110 -0
- pysp/tests/mcmc_convergence_test.py +125 -0
- pysp/tests/mcmc_test.py +688 -0
- pysp/tests/mixture_stability_test.py +203 -0
- pysp/tests/model_helpers_test.py +191 -0
- pysp/tests/mvt_test.py +71 -0
- pysp/tests/numerics_test.py +423 -0
- pysp/tests/nuts_mass_adaptation_test.py +67 -0
- pysp/tests/nuts_torch_test.py +77 -0
- pysp/tests/objective_resolution_test.py +84 -0
- pysp/tests/objectives_test.py +498 -0
- pysp/tests/parallel_test.py +222 -0
- pysp/tests/pareto_expfamily_test.py +59 -0
- pysp/tests/pcfg_engine_test.py +68 -0
- pysp/tests/pde_adjoint_test.py +53 -0
- pysp/tests/pde_nonlinear_test.py +66 -0
- pysp/tests/pitman_yor_test.py +101 -0
- pysp/tests/placement_test.py +391 -0
- pysp/tests/plackett_luce_partial_mle_test.py +57 -0
- pysp/tests/plackett_luce_partial_test.py +59 -0
- pysp/tests/plackett_luce_test.py +73 -0
- pysp/tests/ppca_test.py +77 -0
- pysp/tests/ppl_composite_sampling_test.py +207 -0
- pysp/tests/ppl_constraints_test.py +232 -0
- pysp/tests/ppl_core_test.py +281 -0
- pysp/tests/ppl_engine_test.py +52 -0
- pysp/tests/ppl_hetero_regression_test.py +64 -0
- pysp/tests/ppl_inference_test.py +267 -0
- pysp/tests/ppl_lda_test.py +33 -0
- pysp/tests/ppl_leaf_families_test.py +214 -0
- pysp/tests/ppl_loo_stacking_test.py +44 -0
- pysp/tests/ppl_model_comparison_test.py +132 -0
- pysp/tests/ppl_pde_test.py +130 -0
- pysp/tests/ppl_regression_test.py +125 -0
- pysp/tests/ppl_semimix_test.py +48 -0
- pysp/tests/ppl_soft_constraints_test.py +132 -0
- pysp/tests/ppl_statespace_test.py +38 -0
- pysp/tests/ppl_training_data_test.py +47 -0
- pysp/tests/ppl_vector_params_test.py +158 -0
- pysp/tests/ppl_vmp_test.py +148 -0
- pysp/tests/quantization_test.py +486 -0
- pysp/tests/quantized_hmm_test.py +339 -0
- pysp/tests/quantized_index_test.py +373 -0
- pysp/tests/random_graph_models_test.py +73 -0
- pysp/tests/ray_encoded_data_test.py +64 -0
- pysp/tests/rdpg_test.py +62 -0
- pysp/tests/sampler_accuracy_test.py +215 -0
- pysp/tests/sampler_batching_test.py +89 -0
- pysp/tests/sampler_seed_test.py +453 -0
- pysp/tests/segmental_engine_test.py +49 -0
- pysp/tests/segmental_hmm_test.py +126 -0
- pysp/tests/serialization_test.py +339 -0
- pysp/tests/skellam_test.py +78 -0
- pysp/tests/spanning_tree_test.py +93 -0
- pysp/tests/spark_encoded_data_test.py +110 -0
- pysp/tests/sparse_markov_engine_test.py +61 -0
- pysp/tests/sparse_markov_transform_test.py +117 -0
- pysp/tests/spearman_rho_test.py +72 -0
- pysp/tests/ss_mixture_engine_test.py +55 -0
- pysp/tests/stats_bayes_beta_group_test.py +144 -0
- pysp/tests/stats_bayes_dirichlet_group_test.py +214 -0
- pysp/tests/stats_bayes_dpm_test.py +331 -0
- pysp/tests/stats_bayes_gamma_group_test.py +206 -0
- pysp/tests/stats_bayes_gaussian_test.py +96 -0
- pysp/tests/stats_bayes_markov_test.py +228 -0
- pysp/tests/stats_bayes_mixture_test.py +197 -0
- pysp/tests/stats_bayes_mvgaussian_group_test.py +222 -0
- pysp/tests/stats_bayes_setdist_test.py +136 -0
- pysp/tests/stats_bayes_wrappers_test.py +299 -0
- pysp/tests/streaming_estimation_test.py +272 -0
- pysp/tests/symbolic_export_test.py +140 -0
- pysp/tests/torch_engine_ext_test.py +372 -0
- pysp/tests/torch_engine_test.py +241 -0
- pysp/tests/torchrun_encoded_data_test.py +238 -0
- pysp/tests/tree_hmm_engine_test.py +66 -0
- pysp/tests/tree_hmm_len_test.py +121 -0
- pysp/tests/tree_hmm_sampler_guard_test.py +72 -0
- pysp/tests/truncated_distribution_test.py +78 -0
- pysp/tests/truncation_bound_test.py +90 -0
- pysp/tests/tweedie_test.py +62 -0
- pysp/tests/utils_test.py +66 -0
- pysp/tests/vmf_test.py +123 -0
- pysp/tests/von_mises_test.py +79 -0
- pysp/tests/wave_bayes1_test.py +294 -0
- pysp/tests/wave_bayes2_test.py +299 -0
- pysp/tests/wave_bayes3_test.py +276 -0
- pysp/tests/wave_bayes4_test.py +188 -0
- pysp/tests/wave_core_test.py +239 -0
- pysp/tests/wave_hmmlegacy_test.py +270 -0
- pysp/tests/wave_latent_test.py +298 -0
- pysp/tests/wave_lookback_test.py +162 -0
- pysp/tests/wave_markov_test.py +294 -0
- pysp/tests/wave_multinomial_enum_test.py +221 -0
- pysp/tests/wave_mvn_test.py +267 -0
- pysp/tests/wave_select_test.py +297 -0
- pysp/tests/wave_setdist_test.py +189 -0
- pysp/tests/zero_count_estimate_test.py +123 -0
- pysp/utils/__init__.py +18 -0
- pysp/utils/aliasing.py +95 -0
- pysp/utils/automatic/__init__.py +30 -0
- pysp/utils/automatic/factories.py +449 -0
- pysp/utils/automatic/profiling.py +1926 -0
- pysp/utils/builder.py +53 -0
- pysp/utils/density_rank.py +728 -0
- pysp/utils/em.py +729 -0
- pysp/utils/enumeration.py +1134 -0
- pysp/utils/estimation.py +1056 -0
- pysp/utils/evaluation.py +133 -0
- pysp/utils/fisher.py +2143 -0
- pysp/utils/fit.py +642 -0
- pysp/utils/hvis/__init__.py +295 -0
- pysp/utils/hvis/affinity.py +725 -0
- pysp/utils/hvis/embed.py +380 -0
- pysp/utils/hvis/neighbors.py +341 -0
- pysp/utils/hvis/tsne.py +797 -0
- pysp/utils/mcmc/__init__.py +79 -0
- pysp/utils/mcmc/conjugate.py +128 -0
- pysp/utils/mcmc/gradients.py +91 -0
- pysp/utils/mcmc/nuts_numba.py +245 -0
- pysp/utils/mcmc/nuts_torch.py +249 -0
- pysp/utils/mcmc/parameter_bridge.py +467 -0
- pysp/utils/mcmc/proposals.py +380 -0
- pysp/utils/mcmc/samplers.py +854 -0
- pysp/utils/metrics.py +170 -0
- pysp/utils/objectives.py +762 -0
- pysp/utils/optional_deps.py +60 -0
- pysp/utils/optsutil.py +199 -0
- pysp/utils/parallel/__init__.py +1 -0
- pysp/utils/parallel/lightning_data.py +138 -0
- pysp/utils/parallel/mpi.py +205 -0
- pysp/utils/parallel/multiprocessing.py +284 -0
- pysp/utils/parallel/ray_data.py +159 -0
- pysp/utils/parallel/torchrun.py +250 -0
- pysp/utils/priors.py +274 -0
- pysp/utils/pvalues.py +116 -0
- pysp/utils/quantization/__init__.py +1 -0
- pysp/utils/quantization/core.py +616 -0
- pysp/utils/quantization/parallel.py +189 -0
- pysp/utils/quantization/semiring.py +290 -0
- pysp/utils/serialization.py +402 -0
- pysp/utils/special.py +214 -0
- pysp/utils/streaming.py +240 -0
- pysp/utils/vector.py +590 -0
- pysp_learn-0.2.0.dist-info/METADATA +483 -0
- pysp_learn-0.2.0.dist-info/RECORD +415 -0
- pysp_learn-0.2.0.dist-info/WHEEL +5 -0
- pysp_learn-0.2.0.dist-info/licenses/LICENSE +21 -0
- pysp_learn-0.2.0.dist-info/licenses/NOTICE +21 -0
- pysp_learn-0.2.0.dist-info/top_level.txt +1 -0
pysp/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# NOTE(review): "parallel" and "src" do not appear as top-level ``pysp`` submodules in this
# wheel's file listing (only ``pysp/utils/parallel`` exists), so ``from pysp import *`` would
# presumably fail on them, while packaged subpackages such as ``data``, ``doe``, ``engines``,
# ``infer`` and ``ppl`` are not listed — confirm the intended export set.
__all__ = ["stats", "utils", "models", "parallel", "src"]
|
pysp/arithmetic.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
"""Backend-dispatched arithmetic helpers used by pysparkplug classes."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from pysp.engines import NUMPY_ENGINE, engine_of
|
|
8
|
+
|
|
9
|
+
# Names exported by ``from pysp.arithmetic import *``.
# NOTE(review): this module also defines ``sum`` and ``max`` dispatchers that are not listed
# here, although ``abs`` (which likewise shadows a builtin) is — confirm that is intentional.
__all__ = [
    # array construction / conversion
    "asarray", "zeros", "empty", "arange", "to_numpy",
    # elementwise operations
    "log", "exp", "sqrt", "abs", "where", "maximum", "clip", "floor", "isnan", "isinf",
    # reductions and linear algebra
    "dot", "matmul", "cumsum", "logsumexp", "stack",
    # counting / indexing helpers
    "bincount", "index_add", "unique", "searchsorted",
    # special functions
    "gammaln", "digamma", "betaln", "erf",
    # scalar constants
    "pi", "maxint", "maxrandint", "one", "zero", "two", "half", "inf", "eps",
]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _dispatch(name):
|
|
51
|
+
def fn(*args, **kwargs):
|
|
52
|
+
engine = engine_of(args, default=NUMPY_ENGINE)
|
|
53
|
+
return getattr(engine, name)(*args, **kwargs)
|
|
54
|
+
|
|
55
|
+
fn.__name__ = name
|
|
56
|
+
return fn
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Each public helper below is a thin wrapper produced by ``_dispatch`` for the engine
# method of the same name.  ``abs``, ``sum`` and ``max`` deliberately reuse builtin names,
# so the builtins are shadowed inside this module from here on.

# Array construction and conversion.
asarray, zeros, empty, arange, to_numpy = (
    _dispatch(op) for op in ("asarray", "zeros", "empty", "arange", "to_numpy")
)

# Elementwise operations.
log, exp, sqrt, abs, where, maximum, clip, floor, isnan, isinf = (
    _dispatch(op)
    for op in ("log", "exp", "sqrt", "abs", "where", "maximum", "clip", "floor", "isnan", "isinf")
)

# Reductions, products and stacking.
sum, max, dot, matmul, cumsum, logsumexp, stack = (
    _dispatch(op) for op in ("sum", "max", "dot", "matmul", "cumsum", "logsumexp", "stack")
)

# Counting and index-based helpers.
bincount, index_add, unique, searchsorted = (
    _dispatch(op) for op in ("bincount", "index_add", "unique", "searchsorted")
)

# Special functions.
gammaln, digamma, betaln, erf = (
    _dispatch(op) for op in ("gammaln", "digamma", "betaln", "erf")
)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Scalar constants shared by callers of this module.
pi = np.pi

# Largest signed 32-bit integer; ``maxrandint`` carries the same value.
maxint = (1 << 31) - 1
maxrandint = (1 << 31) - 1

# Frequently used float literals.
one, zero, two, half = 1.0, 0.0, 2.0, 0.5

inf = float("inf")
eps = 1.0e-8
|
pysp/data/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Data adapters and observation representations for pysparkplug.
|
|
2
|
+
|
|
3
|
+
These are input/representation helpers (pandas DataFrame adapters, graph
|
|
4
|
+
observation encoding, Spark RDD sampling) — not probability distributions,
|
|
5
|
+
so they live outside ``pysp.stats``.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from pysp.data.dataframe import dataframe_records, seq_encode_dataframe
|
|
9
|
+
from pysp.data.graph_data import GraphDataEncoder, GraphObservation
|
|
10
|
+
from pysp.data.rdd_sampler import sample_rdd, sample_seq_as_rdd, take_sample
|
|
11
|
+
|
|
12
|
+
# Public re-exports of the data adapters imported above.
__all__ = [
    "GraphDataEncoder", "GraphObservation",
    "dataframe_records",
    "sample_rdd", "sample_seq_as_rdd",
    "seq_encode_dataframe",
    "take_sample",
]
|
pysp/data/dataframe.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""Pandas DataFrame adapters for the sequence-encoded stats API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from pysp.stats.compute.pdist import DataSequenceEncoder, ParameterEstimator, SequenceEncodableProbabilityDistribution
|
|
9
|
+
|
|
10
|
+
# Accepted column selections: a single column name, a sequence of entries (an entry may be a
# 2-tuple whose second item is the source column, see ``_field_source``), or ``None`` meaning
# every DataFrame column.
FieldSpec = str | Sequence[Any] | None
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _field_source(field: Any) -> Any:
|
|
14
|
+
if isinstance(field, tuple) and len(field) == 2:
|
|
15
|
+
return field[1]
|
|
16
|
+
return field
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def dataframe_records(df: Any, fields: FieldSpec = None, as_dict: bool = False) -> list[Any]:
    """Convert DataFrame columns into observation records for ``seq_encode``.

    Args:
        df: DataFrame-like object; ``columns``, ``loc``, ``__getitem__`` and
            ``itertuples`` are used.
        fields: ``None`` to use every column, a single column name, or a
            sequence of entries.  A 2-tuple entry is read as a pair whose
            second item is the source column (see ``_field_source``).
        as_dict: When true, each row is returned as a mapping keyed by the
            selected source column names.

    Returns:
        A list with one item per row: scalars when a single column is
        selected, tuples in the requested order when several are selected,
        or dicts when ``as_dict`` is true.

    Raises:
        KeyError: If a selected source column is not in ``df.columns``.
        ValueError: If the selection is empty.
    """
    if fields is None:
        # Use the column labels verbatim.  Passing them through ``_field_source`` would
        # misread 2-tuple labels (two-level MultiIndex columns) as (field, source) pairs
        # and then report every column as missing.
        source_list = list(df.columns)
    else:
        field_list = [fields] if isinstance(fields, str) else list(fields)
        source_list = [_field_source(name) for name in field_list]

    missing = [name for name in source_list if name not in df.columns]
    if missing:
        raise KeyError("DataFrame is missing fields: %s" % ", ".join(map(str, missing)))

    if not source_list:
        raise ValueError("fields must select at least one DataFrame column.")

    if as_dict:
        selected = df.loc[:, source_list]
        return [dict(zip(source_list, row)) for row in selected.itertuples(index=False, name=None)]

    if len(source_list) == 1:
        # A single column yields scalar observations rather than 1-tuples.
        return df[source_list[0]].tolist()

    return list(df.loc[:, source_list].itertuples(index=False, name=None))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def seq_encode_dataframe(
    df: Any,
    fields: FieldSpec = None,
    encoder: DataSequenceEncoder | None = None,
    estimator: ParameterEstimator | None = None,
    model: SequenceEncodableProbabilityDistribution | None = None,
    num_chunks: int = 1,
    chunk_size: int | None = None,
):
    """Sequence-encode selected DataFrame columns with the ordinary stats API.

    When ``fields`` is omitted and ``model`` is a ``RecordDistribution`` (or,
    failing that, ``estimator`` is a ``RecordEstimator``), the selection is
    taken from that object's ``fields``/``sources`` pairs.  Rows are converted
    with ``dataframe_records`` -- as dicts whenever a record model or record
    estimator is supplied -- and handed to ``pysp.stats.seq_encode`` together
    with the remaining arguments.
    """
    # Imported lazily inside the function, as in the original implementation.
    from pysp.stats import seq_encode
    from pysp.stats.combinator.record import RecordDistribution, RecordEstimator

    record_model = isinstance(model, RecordDistribution)
    record_estimator = isinstance(estimator, RecordEstimator)

    if fields is None:
        # The model takes precedence over the estimator when both are record types.
        if record_model:
            fields = tuple(zip(model.fields, model.sources))
        elif record_estimator:
            fields = tuple(zip(estimator.fields, estimator.sources))

    records = dataframe_records(df, fields=fields, as_dict=record_model or record_estimator)
    return seq_encode(
        records,
        encoder=encoder,
        estimator=estimator,
        model=model,
        num_chunks=num_chunks,
        chunk_size=chunk_size,
    )
|
pysp/data/graph_data.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
"""Shared graph observation encoding helpers for stats graph distributions.
|
|
2
|
+
|
|
3
|
+
Graph observations may be square binary adjacency matrices, NetworkX-like graph
|
|
4
|
+
objects, ``(adjacency, block_assignments)`` pairs, or mappings with
|
|
5
|
+
``adjacency``/``adj`` and optional ``block_assignments``/``blocks``.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import math
|
|
9
|
+
from collections.abc import Mapping, Sequence
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import scipy.sparse as sp
|
|
17
|
+
except Exception: # pragma: no cover - scipy is a package dependency in normal use.
|
|
18
|
+
sp = None
|
|
19
|
+
|
|
20
|
+
from pysp.stats.compute.pdist import DataSequenceEncoder
|
|
21
|
+
|
|
22
|
+
# Clipping margin: ``_clip_prob`` confines probabilities to [_EPS, 1 - _EPS] so the
# ``log``/``log1p`` calls in ``_bernoulli_log_likelihood`` never receive 0 or 1.
_EPS = 1.0e-12
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True)
|
|
26
|
+
class GraphObservation:
|
|
27
|
+
"""Canonical binary graph observation used by graph encoders."""
|
|
28
|
+
|
|
29
|
+
adjacency: np.ndarray
|
|
30
|
+
block_assignments: np.ndarray | None = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _clip_prob(p: float) -> float:
|
|
34
|
+
pp = float(p)
|
|
35
|
+
if not np.isfinite(pp) or pp < 0.0 or pp > 1.0:
|
|
36
|
+
raise ValueError("probabilities must be finite values in [0, 1].")
|
|
37
|
+
return float(np.clip(pp, _EPS, 1.0 - _EPS))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _bernoulli_log_likelihood(successes: float, total: float, p: float) -> float:
    """Return ``successes * log(p) + (total - successes) * log(1 - p)``.

    ``p`` is first validated and clipped by ``_clip_prob``, so both logarithm
    terms are finite.
    """
    prob = _clip_prob(p)
    failures = total - successes
    success_term = successes * math.log(prob)
    failure_term = failures * math.log1p(-prob)
    return float(success_term + failure_term)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _edge_indices(n: int, directed: bool, self_loops: bool):
|
|
46
|
+
if directed:
|
|
47
|
+
for i in range(n):
|
|
48
|
+
for j in range(n):
|
|
49
|
+
if self_loops or i != j:
|
|
50
|
+
yield i, j
|
|
51
|
+
else:
|
|
52
|
+
start = 0 if self_loops else 1
|
|
53
|
+
for i in range(n):
|
|
54
|
+
for j in range(i + start, n):
|
|
55
|
+
yield i, j
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _networkx_like_to_adjacency(graph: Any) -> tuple[np.ndarray, np.ndarray | None]:
|
|
59
|
+
nodes = list(graph.nodes())
|
|
60
|
+
index = {node: i for i, node in enumerate(nodes)}
|
|
61
|
+
adj = np.zeros((len(nodes), len(nodes)), dtype=np.float64)
|
|
62
|
+
directed = bool(graph.is_directed()) if hasattr(graph, "is_directed") else False
|
|
63
|
+
|
|
64
|
+
for edge in graph.edges(data=True):
|
|
65
|
+
if len(edge) == 3:
|
|
66
|
+
u, v, data = edge
|
|
67
|
+
else:
|
|
68
|
+
u, v = edge[:2]
|
|
69
|
+
data = {}
|
|
70
|
+
weight = 1.0
|
|
71
|
+
if isinstance(data, Mapping):
|
|
72
|
+
weight = data.get("weight", 1.0)
|
|
73
|
+
adj[index[u], index[v]] = weight
|
|
74
|
+
if not directed and u != v:
|
|
75
|
+
adj[index[v], index[u]] = weight
|
|
76
|
+
|
|
77
|
+
assignments = []
|
|
78
|
+
found_assignment = False
|
|
79
|
+
for node in nodes:
|
|
80
|
+
value = None
|
|
81
|
+
try:
|
|
82
|
+
attrs = graph.nodes[node]
|
|
83
|
+
if isinstance(attrs, Mapping):
|
|
84
|
+
if "block" in attrs:
|
|
85
|
+
value = attrs["block"]
|
|
86
|
+
elif "block_assignment" in attrs:
|
|
87
|
+
value = attrs["block_assignment"]
|
|
88
|
+
except Exception:
|
|
89
|
+
value = None
|
|
90
|
+
assignments.append(value)
|
|
91
|
+
found_assignment = found_assignment or value is not None
|
|
92
|
+
|
|
93
|
+
if found_assignment:
|
|
94
|
+
if any(value is None for value in assignments):
|
|
95
|
+
raise ValueError("all graph nodes must have block labels when any node has one.")
|
|
96
|
+
return adj, np.asarray(assignments, dtype=np.int64)
|
|
97
|
+
return adj, None
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _edge_list_to_adjacency(edges: Sequence[Any], num_nodes: int, directed: bool) -> np.ndarray:
|
|
101
|
+
n = int(num_nodes)
|
|
102
|
+
if n < 0:
|
|
103
|
+
raise ValueError("num_nodes must be non-negative.")
|
|
104
|
+
adj = np.zeros((n, n), dtype=np.float64)
|
|
105
|
+
for edge in edges:
|
|
106
|
+
if len(edge) < 2:
|
|
107
|
+
raise ValueError("edge entries must contain at least two node indices.")
|
|
108
|
+
i = int(edge[0])
|
|
109
|
+
j = int(edge[1])
|
|
110
|
+
if i < 0 or i >= n or j < 0 or j >= n:
|
|
111
|
+
raise ValueError("edge node indices must be in [0, num_nodes).")
|
|
112
|
+
weight = float(edge[2]) if len(edge) >= 3 else 1.0
|
|
113
|
+
adj[i, j] = weight
|
|
114
|
+
if not directed and i != j:
|
|
115
|
+
adj[j, i] = weight
|
|
116
|
+
return adj
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _as_adjacency(adjacency: Any) -> np.ndarray:
|
|
120
|
+
if hasattr(adjacency, "nodes") and hasattr(adjacency, "edges"):
|
|
121
|
+
adjacency, _ = _networkx_like_to_adjacency(adjacency)
|
|
122
|
+
elif sp is not None and sp.issparse(adjacency):
|
|
123
|
+
adjacency = adjacency.toarray()
|
|
124
|
+
|
|
125
|
+
adj = np.asarray(adjacency, dtype=np.float64)
|
|
126
|
+
if adj.ndim != 2 or adj.shape[0] != adj.shape[1]:
|
|
127
|
+
raise ValueError("graph adjacency must be a square matrix.")
|
|
128
|
+
if np.any(~np.isfinite(adj)):
|
|
129
|
+
raise ValueError("graph adjacency must be finite.")
|
|
130
|
+
if np.any((adj != 0.0) & (adj != 1.0)):
|
|
131
|
+
raise ValueError("graph adjacency must contain binary values 0/1.")
|
|
132
|
+
return adj
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _as_assignments(assignments: Any | None, n: int) -> np.ndarray | None:
|
|
136
|
+
if assignments is None:
|
|
137
|
+
return None
|
|
138
|
+
rv = np.asarray(assignments, dtype=np.int64)
|
|
139
|
+
if rv.ndim != 1 or rv.shape[0] != n:
|
|
140
|
+
raise ValueError("block assignments must be a length-%d one-dimensional sequence." % n)
|
|
141
|
+
if rv.size and rv.min() < 0:
|
|
142
|
+
raise ValueError("block assignments must be non-negative integers.")
|
|
143
|
+
return rv
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _extract_observation(x: Any, directed: bool = False, fallback_assignments: Any | None = None) -> GraphObservation:
    """Normalize one graph-like input into a ``GraphObservation``.

    Accepted forms, tested in this order:

    * ``GraphObservation`` -- the adjacency is re-validated and the object's own
      block assignments are kept as they are (the fallback is not consulted).
    * ``Mapping`` -- adjacency comes from the first matching entry among
      ``"adjacency"``, ``"adj"``, ``"graph"``, or ``"edges"`` together with
      ``"num_nodes"``; assignments come from ``"block_assignments"`` or
      ``"blocks"``.
    * 2-tuple -- ``(adjacency, assignments)``.
    * 2-element list whose first item is not a scalar -- tried as
      ``[adjacency, assignments]`` and, if that fails, as an adjacency matrix.
    * Object with ``nodes`` and ``edges`` attributes -- converted with
      ``_networkx_like_to_adjacency``.
    * Anything else -- handed to ``_as_adjacency``.

    Whenever assignments can come from several places the precedence is:
    explicitly supplied, then derived from the graph object, then
    ``fallback_assignments``.  A value of ``None`` counts as "not supplied".

    Args:
        x: The observation in any of the forms above.
        directed: Forwarded to ``_edge_list_to_adjacency`` for edge-list input.
        fallback_assignments: Block labels used when none are found in ``x``.

    Returns:
        A ``GraphObservation`` built from the validated adjacency matrix and
        the assignments checked by ``_as_assignments`` (possibly ``None``).

    Raises:
        ValueError: If a mapping holds none of the recognized adjacency keys,
            or if adjacency/assignment validation fails.
    """
    if isinstance(x, GraphObservation):
        adj = _as_adjacency(x.adjacency)
        assignments = x.block_assignments
    elif isinstance(x, Mapping):
        # Resolve explicit labels first and leave the fallback for last.  Using
        # the fallback as the ``.get`` default would let it override labels
        # embedded in ``x["graph"]``, unlike the bare-graph branch below.
        assignments = x.get("block_assignments")
        if assignments is None:
            assignments = x.get("blocks")

        if "adjacency" in x:
            adj = _as_adjacency(x["adjacency"])
        elif "adj" in x:
            adj = _as_adjacency(x["adj"])
        elif "graph" in x:
            adj, graph_assignments = _networkx_like_to_adjacency(x["graph"])
            if assignments is None:
                assignments = graph_assignments
            adj = _as_adjacency(adj)
        elif "edges" in x and "num_nodes" in x:
            # Edge entries may carry a weight; run the result through the same
            # 0/1 validation that every other input form goes through.
            adj = _as_adjacency(
                _edge_list_to_adjacency(x["edges"], int(x["num_nodes"]), directed=directed)
            )
        else:
            raise ValueError("graph mapping must contain adjacency, adj, graph, or edges+num_nodes.")

        if assignments is None:
            assignments = fallback_assignments
    elif isinstance(x, tuple) and len(x) == 2:
        adj = _as_adjacency(x[0])
        assignments = x[1] if x[1] is not None else fallback_assignments
    elif isinstance(x, list) and len(x) == 2 and not np.isscalar(x[0]):
        # Ambiguous input: ``[adjacency, assignments]`` or a 2x2 matrix given as
        # nested lists.  Deliberate best effort: try the pair reading first and
        # fall back to the matrix reading on any failure.
        try:
            adj = _as_adjacency(x[0])
            assignments = x[1] if x[1] is not None else fallback_assignments
        except Exception:
            adj = _as_adjacency(x)
            assignments = fallback_assignments
    elif hasattr(x, "nodes") and hasattr(x, "edges"):
        adj, graph_assignments = _networkx_like_to_adjacency(x)
        assignments = graph_assignments if graph_assignments is not None else fallback_assignments
        adj = _as_adjacency(adj)
    else:
        adj = _as_adjacency(x)
        assignments = fallback_assignments

    return GraphObservation(adj, _as_assignments(assignments, adj.shape[0]))
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _edge_counts(adj: np.ndarray, directed: bool, self_loops: bool) -> tuple[float, float]:
    """Count candidate node pairs and sum their adjacency entries.

    The node pairs are whatever ``_edge_indices`` yields for
    ``adj.shape[0]`` nodes with the given ``directed`` / ``self_loops`` flags.

    Args:
        adj: Adjacency matrix indexed as ``adj[i, j]``.
        directed: Forwarded to ``_edge_indices``.
        self_loops: Forwarded to ``_edge_indices``.

    Returns:
        ``(total, successes)``: the number of pairs visited (as a float) and
        the sum of ``adj[i, j]`` over those pairs.
    """
    num_pairs = 0
    edge_sum = 0.0
    pairs = _edge_indices(adj.shape[0], directed=directed, self_loops=self_loops)
    for row, col in pairs:
        edge_sum += adj[row, col]
        num_pairs += 1
    return float(num_pairs), edge_sum
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _validate_block_probs(block_probs: Any) -> np.ndarray:
    """Convert ``block_probs`` to ``float64`` and check it is a probability matrix.

    Args:
        block_probs: Array-like of block-to-block probabilities.

    Returns:
        The input as a square, non-empty two-dimensional ``float64`` array.

    Raises:
        ValueError: If the array is not square 2-D, has zero blocks, or holds
            a value that is non-finite or outside ``[0, 1]``.
    """
    matrix = np.asarray(block_probs, dtype=np.float64)

    is_square = matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]
    if not is_square:
        raise ValueError("block_probs must be a square matrix.")

    num_blocks = matrix.shape[0]
    if num_blocks == 0:
        raise ValueError("block_probs must contain at least one block.")

    if not np.all(np.isfinite(matrix)) or (matrix < 0.0).any() or (matrix > 1.0).any():
        raise ValueError("block probabilities must be finite and in [0, 1].")

    return matrix
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _validate_block_indices(assignments: np.ndarray, num_blocks: int) -> None:
    """Check that every block label lies in ``[0, num_blocks)``.

    Args:
        assignments: Array of block labels.
        num_blocks: Number of blocks the labels must index into.

    Raises:
        ValueError: If ``assignments`` is not one-dimensional, or if any label
            is negative or not less than ``num_blocks``.
    """
    if assignments.ndim != 1:
        raise ValueError("block assignments must be a one-dimensional sequence.")

    # An empty label vector has no min/max and is trivially valid.
    if not assignments.size:
        return

    out_of_range = assignments.min() < 0 or assignments.max() >= num_blocks
    if out_of_range:
        raise ValueError("block assignments must index block_probs.")
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _normalize_prior(block_prior: Any | None, num_blocks: int) -> np.ndarray:
    """Return a block prior over ``num_blocks`` blocks that sums to one.

    Args:
        block_prior: ``None`` for a uniform prior, or an array-like of
            non-negative weights, one per block.
        num_blocks: Number of blocks; must be positive.

    Returns:
        A one-dimensional ``float64`` array: the uniform vector when
        ``block_prior`` is ``None``, otherwise the weights divided by their sum.

    Raises:
        ValueError: If ``num_blocks`` is not positive, if the weights do not
            form a length-``num_blocks`` vector, or if they contain negative or
            non-finite values or do not have a positive sum.
    """
    if num_blocks <= 0:
        raise ValueError("num_blocks must be positive.")

    if block_prior is None:
        uniform_mass = 1.0 / float(num_blocks)
        return np.full(int(num_blocks), uniform_mass, dtype=np.float64)

    weights = np.asarray(block_prior, dtype=np.float64)
    if weights.ndim != 1 or weights.shape[0] != num_blocks:
        raise ValueError("block_prior must be a length-num_blocks vector.")

    if not np.all(np.isfinite(weights)) or (weights < 0.0).any():
        raise ValueError("block_prior must contain non-negative finite values with positive sum.")

    mass = weights.sum()
    if mass <= 0.0:
        raise ValueError("block_prior must contain non-negative finite values with positive sum.")

    return weights / mass
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class GraphDataEncoder(DataSequenceEncoder):
    """Sequence encoder that turns graph inputs into ``GraphObservation`` objects.

    Each item of a sequence is normalized with ``_extract_observation`` using
    this encoder's ``directed`` flag and optional fallback block assignments.
    """

    def __init__(self, directed: bool = False, fallback_assignments: Any | None = None) -> None:
        """Store the encoder settings.

        Args:
            directed: Forwarded to ``_extract_observation`` for every item.
            fallback_assignments: Optional block labels forwarded to
                ``_extract_observation`` as its fallback.
        """
        self.directed = bool(directed)
        # Held as a plain tuple of Python ints (or None) so that ``__eq__`` can
        # compare two encoders with an ordinary tuple comparison.
        if fallback_assignments is None:
            self.fallback_assignments = None
        else:
            labels = np.asarray(fallback_assignments, dtype=np.int64)
            self.fallback_assignments = tuple(int(label) for label in labels)

    def __str__(self) -> str:
        """Return a short description; the fallback assignments are not included."""
        return f"GraphDataEncoder(directed={self.directed!r})"

    def __eq__(self, other: object) -> bool:
        """Return True for another ``GraphDataEncoder`` with identical settings."""
        if not isinstance(other, GraphDataEncoder):
            return False
        same_direction = self.directed == other.directed
        return same_direction and self.fallback_assignments == other.fallback_assignments

    def seq_encode(self, x: Sequence[Any]) -> tuple[GraphObservation, ...]:
        """Encode every item of ``x`` and return the observations as a tuple."""
        if self.fallback_assignments is None:
            fallback = None
        else:
            # Convert the stored tuple back to an int64 array once per call.
            fallback = np.asarray(self.fallback_assignments, dtype=np.int64)

        encoded = []
        for item in x:
            encoded.append(
                _extract_observation(item, directed=self.directed, fallback_assignments=fallback)
            )
        return tuple(encoded)

    def nbytes(self, x: Any) -> int:
        """Return the summed ``nbytes`` of adjacency and assignment arrays in ``x``.

        Observations whose ``block_assignments`` is ``None`` contribute only
        their adjacency array.
        """
        size = 0
        for observation in x:
            size += int(observation.adjacency.nbytes)
            labels = observation.block_assignments
            if labels is not None:
                size += int(labels.nbytes)
        return size
|
pysp/data/rdd_sampler.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
try:
|
|
2
|
+
from pyspark import SparkConf, SparkContext
|
|
3
|
+
except ImportError:
|
|
4
|
+
SparkContext = SparkConf = None # pip install pysparkplug[spark]
|
|
5
|
+
import pickle
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from numpy.random import RandomState
|
|
10
|
+
|
|
11
|
+
from pysp.arithmetic import *
|
|
12
|
+
from pysp.arithmetic import maxrandint
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def take_sample(rdd: Any, with_replacement: bool, n: int, seed: int | None = None) -> list[Any]:
    """Draw a seeded, randomly ordered sample of up to ``n`` items from ``rdd``.

    The items are paired with ids via ``zipWithUniqueId``, sampled with
    ``takeSample``, sorted by id, and then shuffled with the local generator.
    The sort presumably makes the final order depend only on ``seed`` and the
    sampled ids rather than on the order in which the sample was returned.

    Args:
        rdd: Object providing ``zipWithUniqueId().takeSample(...)``; expected
            to be a PySpark RDD.
        with_replacement: Forwarded to ``takeSample``.
        n: Requested sample size.
        seed: Seed for the local ``RandomState``; ``None`` lets NumPy choose.

    Returns:
        A list of sampled items in shuffled order.  It holds fewer than ``n``
        items when ``takeSample`` returns fewer.
    """
    rng = RandomState(seed)
    tagged = rdd.zipWithUniqueId().takeSample(with_replacement, n, rng.randint(0, maxrandint))

    by_id = np.argsort([pair[1] for pair in tagged])
    sample = [tagged[i][0] for i in by_id]

    # Size the permutation by what was actually returned: sampling without
    # replacement from an RDD holding fewer than ``n`` items yields a shorter
    # list, and a length-``n`` permutation would then index past its end.
    shuffle = np.argsort(rng.uniform(size=len(sample)))
    return [sample[i] for i in shuffle]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def sample_seq_as_rdd(sc, dist, seq_len, count_per_split, num_splits, seed=None):
    """Build an RDD of values drawn with ``sample_seq`` from ``dist``.

    ``dist`` is broadcast through ``sc``.  One integer seed per split is drawn
    from ``RandomState(seed)`` and the seeds are parallelized over
    ``num_splits`` partitions.  For every seed in a partition a sampler is made
    with ``dist.sampler(seed=...)`` and the items of
    ``sample_seq(seq_len, size=count_per_split)`` are emitted.

    Args:
        sc: Context providing ``broadcast`` and ``parallelize`` (expected to be
            a ``SparkContext``).
        dist: Object providing ``sampler(seed=...)``.
        seq_len: First argument forwarded to ``sample_seq``.
        count_per_split: ``size`` argument forwarded to ``sample_seq``.
        num_splits: Number of seeds drawn and number of partitions requested.
        seed: Seed for the seed generator; ``None`` lets NumPy choose.

    Returns:
        The result of ``mapPartitions`` over the parallelized seeds.
    """
    shared_dist = sc.broadcast(dist)
    split_seeds = RandomState(seed).randint(0, maxrandint, size=num_splits)

    def fmap(partition_seeds):
        local_dist = shared_dist.value
        # Build every sampler for the partition first, then draw from each in turn.
        samplers = [local_dist.sampler(seed=s) for s in partition_seeds]
        drawn = []
        for sampler in samplers:
            drawn.extend(sampler.sample_seq(seq_len, size=count_per_split))
        return iter(drawn)

    seed_rdd = sc.parallelize(split_seeds, num_splits)
    return seed_rdd.mapPartitions(fmap, True)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def sample_rdd(sc, dist, count_per_split, num_splits, seed=None):
    """Build an RDD of values drawn with ``sample`` from ``dist``.

    ``dist`` is pickled with protocol 0 and the resulting bytes are broadcast
    through ``sc``; each partition unpickles its own copy.  One integer seed
    per split is drawn from ``RandomState(seed)`` and the seeds are
    parallelized over ``num_splits`` partitions.  For every seed in a partition
    a sampler is made with ``dist.sampler(seed=...)`` and the items of
    ``sample(size=count_per_split)`` are emitted.

    Args:
        sc: Context providing ``broadcast`` and ``parallelize`` (expected to be
            a ``SparkContext``).
        dist: Picklable object providing ``sampler(seed=...)``.
        count_per_split: ``size`` argument forwarded to ``sample``.
        num_splits: Number of seeds drawn and number of partitions requested.
        seed: Seed for the seed generator; ``None`` lets NumPy choose.

    Returns:
        The result of ``mapPartitions`` over the parallelized seeds.
    """
    payload = pickle.dumps(dist, protocol=0)
    shared_payload = sc.broadcast(payload)
    split_seeds = RandomState(seed).randint(0, maxrandint, size=num_splits)

    def fmap(partition_seeds):
        # The broadcast value is the pickle produced above from ``dist``.
        local_dist = pickle.loads(shared_payload.value)
        # Build every sampler for the partition first, then draw from each in turn.
        samplers = [local_dist.sampler(seed=s) for s in partition_seeds]
        drawn = []
        for sampler in samplers:
            drawn.extend(sampler.sample(size=count_per_split))
        return iter(drawn)

    seed_rdd = sc.parallelize(split_seeds, num_splits)
    return seed_rdd.mapPartitions(fmap, True)
|
pysp/doe/__init__.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Design of experiments (DoE) for pysparkplug.
|
|
2
|
+
|
|
3
|
+
This package builds experiment designs over a bounded input space and (in later additions)
|
|
4
|
+
sequential / Bayesian-optimization loops on top of the existing GP and regression machinery.
|
|
5
|
+
|
|
6
|
+
The first surface is space-filling and classical design generators, all returning a plain
|
|
7
|
+
``(n, d)`` numpy matrix of input points scaled into the supplied per-dimension bounds:
|
|
8
|
+
|
|
9
|
+
>>> from pysp.doe import latin_hypercube
|
|
10
|
+
>>> x = latin_hypercube([(0.0, 1.0), (-2.0, 2.0)], n=8, seed=0)
|
|
11
|
+
>>> x.shape
|
|
12
|
+
(8, 2)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from pysp.doe.bayesopt import (
|
|
18
|
+
BayesOptResult,
|
|
19
|
+
available_acquisitions,
|
|
20
|
+
expected_improvement,
|
|
21
|
+
minimize,
|
|
22
|
+
probability_of_improvement,
|
|
23
|
+
propose_batch,
|
|
24
|
+
propose_next,
|
|
25
|
+
register_acquisition,
|
|
26
|
+
upper_confidence_bound,
|
|
27
|
+
)
|
|
28
|
+
from pysp.doe.constrained import (
|
|
29
|
+
ConstrainedBayesOptResult,
|
|
30
|
+
constrained_minimize,
|
|
31
|
+
probability_of_feasibility,
|
|
32
|
+
propose_next_constrained,
|
|
33
|
+
)
|
|
34
|
+
from pysp.doe.designs import (
|
|
35
|
+
Bounds,
|
|
36
|
+
full_factorial,
|
|
37
|
+
halton_design,
|
|
38
|
+
latin_hypercube,
|
|
39
|
+
maximin_latin_hypercube,
|
|
40
|
+
random_design,
|
|
41
|
+
sobol_design,
|
|
42
|
+
)
|
|
43
|
+
from pysp.doe.multiobjective import (
|
|
44
|
+
MultiObjectiveResult,
|
|
45
|
+
multi_minimize,
|
|
46
|
+
pareto_mask,
|
|
47
|
+
)
|
|
48
|
+
from pysp.doe.optimal import (
|
|
49
|
+
a_criterion,
|
|
50
|
+
available_criteria,
|
|
51
|
+
d_criterion,
|
|
52
|
+
i_criterion,
|
|
53
|
+
optimal_design,
|
|
54
|
+
polynomial_features,
|
|
55
|
+
register_criterion,
|
|
56
|
+
)
|
|
57
|
+
from pysp.doe.optimizer import BayesianOptimizer
|
|
58
|
+
|
|
59
|
+
# Public names re-exported by this package, grouped by the submodule each one
# is imported from at the top of this file.
__all__ = [
    # pysp.doe.designs
    "Bounds",
    "full_factorial",
    "halton_design",
    "latin_hypercube",
    "maximin_latin_hypercube",
    "random_design",
    "sobol_design",
    # pysp.doe.bayesopt
    "BayesOptResult",
    "expected_improvement",
    "probability_of_improvement",
    "upper_confidence_bound",
    "register_acquisition",
    "available_acquisitions",
    "minimize",
    "propose_next",
    "propose_batch",
    # pysp.doe.optimal
    "optimal_design",
    "polynomial_features",
    "d_criterion",
    "a_criterion",
    "i_criterion",
    "register_criterion",
    "available_criteria",
    # pysp.doe.constrained
    "ConstrainedBayesOptResult",
    "probability_of_feasibility",
    "propose_next_constrained",
    "constrained_minimize",
    # pysp.doe.multiobjective
    "MultiObjectiveResult",
    "pareto_mask",
    "multi_minimize",
    # pysp.doe.optimizer
    "BayesianOptimizer",
]
|