replay-rec 0.20.3rc0__tar.gz → 0.21.0rc0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/PKG-INFO +3 -3
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/README.md +1 -1
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/pyproject.toml +33 -23
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/__init__.py +1 -1
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/dataset.py +11 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/nn/__init__.py +3 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/__init__.py +22 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/collate.py +29 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/constants/batches.py +8 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/constants/device.py +3 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/constants/metadata.py +5 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/impl/indexing.py +123 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/impl/masking.py +20 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/impl/utils.py +17 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/info/partitioning.py +132 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/info/replicas.py +67 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/info/worker_info.py +43 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/iterator.py +61 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/parquet_module.py +178 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay_rec-0.21.0rc0/replay/data/nn/parquet/utils/compute_length.py +66 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/nn/schema.py +12 -14
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/nn/sequence_tokenizer.py +5 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/nn/sequential_dataset.py +4 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/nn/torch_sequential_dataset.py +5 -0
- replay_rec-0.21.0rc0/replay/data/utils/batching.py +69 -0
- replay_rec-0.21.0rc0/replay/data/utils/typing/dtype.py +65 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/extensions/spark_custom_models/als_extension.py +1 -1
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/torch_metrics_builder.py +20 -14
- replay_rec-0.21.0rc0/replay/models/extensions/__init__.py +0 -0
- replay_rec-0.21.0rc0/replay/models/extensions/ann/__init__.py +0 -0
- replay_rec-0.21.0rc0/replay/models/extensions/ann/entities/__init__.py +0 -0
- replay_rec-0.21.0rc0/replay/models/extensions/ann/index_builders/__init__.py +0 -0
- replay_rec-0.21.0rc0/replay/models/extensions/ann/index_inferers/__init__.py +0 -0
- replay_rec-0.21.0rc0/replay/models/extensions/ann/index_stores/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/loss/sce.py +2 -7
- replay_rec-0.21.0rc0/replay/models/nn/optimizer_utils/__init__.py +9 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/bert4rec/model.py +11 -11
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/postprocessors/_base.py +5 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/sasrec/dataset.py +81 -26
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/sasrec/lightning.py +86 -24
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/sasrec/model.py +14 -9
- replay_rec-0.21.0rc0/replay/nn/__init__.py +8 -0
- replay_rec-0.21.0rc0/replay/nn/agg.py +109 -0
- replay_rec-0.21.0rc0/replay/nn/attention.py +158 -0
- replay_rec-0.21.0rc0/replay/nn/embedding.py +283 -0
- replay_rec-0.21.0rc0/replay/nn/ffn.py +135 -0
- replay_rec-0.21.0rc0/replay/nn/head.py +49 -0
- replay_rec-0.21.0rc0/replay/nn/lightning/__init__.py +1 -0
- replay_rec-0.21.0rc0/replay/nn/lightning/callback/__init__.py +9 -0
- replay_rec-0.21.0rc0/replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay_rec-0.21.0rc0/replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay_rec-0.21.0rc0/replay/nn/lightning/module.py +123 -0
- replay_rec-0.21.0rc0/replay/nn/lightning/optimizer.py +60 -0
- replay_rec-0.21.0rc0/replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay_rec-0.21.0rc0/replay/nn/lightning/postprocessor/_base.py +51 -0
- replay_rec-0.21.0rc0/replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay_rec-0.21.0rc0/replay/nn/lightning/scheduler.py +91 -0
- replay_rec-0.21.0rc0/replay/nn/loss/__init__.py +22 -0
- replay_rec-0.21.0rc0/replay/nn/loss/base.py +197 -0
- replay_rec-0.21.0rc0/replay/nn/loss/bce.py +216 -0
- replay_rec-0.21.0rc0/replay/nn/loss/ce.py +317 -0
- replay_rec-0.21.0rc0/replay/nn/loss/login_ce.py +373 -0
- replay_rec-0.21.0rc0/replay/nn/loss/logout_ce.py +230 -0
- replay_rec-0.21.0rc0/replay/nn/mask.py +87 -0
- replay_rec-0.21.0rc0/replay/nn/normalization.py +9 -0
- replay_rec-0.21.0rc0/replay/nn/output.py +37 -0
- replay_rec-0.21.0rc0/replay/nn/sequential/__init__.py +9 -0
- replay_rec-0.21.0rc0/replay/nn/sequential/sasrec/__init__.py +7 -0
- replay_rec-0.21.0rc0/replay/nn/sequential/sasrec/agg.py +53 -0
- replay_rec-0.21.0rc0/replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay_rec-0.21.0rc0/replay/nn/sequential/sasrec/model.py +377 -0
- replay_rec-0.21.0rc0/replay/nn/sequential/sasrec/transformer.py +107 -0
- replay_rec-0.21.0rc0/replay/nn/sequential/twotower/__init__.py +2 -0
- replay_rec-0.21.0rc0/replay/nn/sequential/twotower/model.py +674 -0
- replay_rec-0.21.0rc0/replay/nn/sequential/twotower/reader.py +89 -0
- replay_rec-0.21.0rc0/replay/nn/transform/__init__.py +22 -0
- replay_rec-0.21.0rc0/replay/nn/transform/copy.py +38 -0
- replay_rec-0.21.0rc0/replay/nn/transform/grouping.py +39 -0
- replay_rec-0.21.0rc0/replay/nn/transform/negative_sampling.py +182 -0
- replay_rec-0.21.0rc0/replay/nn/transform/next_token.py +100 -0
- replay_rec-0.21.0rc0/replay/nn/transform/rename.py +33 -0
- replay_rec-0.21.0rc0/replay/nn/transform/reshape.py +41 -0
- replay_rec-0.21.0rc0/replay/nn/transform/sequence_roll.py +48 -0
- replay_rec-0.21.0rc0/replay/nn/transform/template/__init__.py +2 -0
- replay_rec-0.21.0rc0/replay/nn/transform/template/sasrec.py +53 -0
- replay_rec-0.21.0rc0/replay/nn/transform/template/twotower.py +22 -0
- replay_rec-0.21.0rc0/replay/nn/transform/token_mask.py +69 -0
- replay_rec-0.21.0rc0/replay/nn/transform/trim.py +51 -0
- replay_rec-0.21.0rc0/replay/nn/utils.py +28 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/preprocessing/filters.py +128 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/preprocessing/label_encoder.py +36 -33
- replay_rec-0.21.0rc0/replay/preprocessing/utils.py +209 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/splitters/__init__.py +1 -0
- replay_rec-0.21.0rc0/replay/splitters/random_next_n_splitter.py +224 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/utils/common.py +10 -4
- replay_rec-0.20.3rc0/replay/models/nn/optimizer_utils/__init__.py +0 -4
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/LICENSE +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/NOTICE +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/dataset_utils/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/dataset_utils/dataset_label_encoder.py +0 -0
- {replay_rec-0.20.3rc0/replay/experimental → replay_rec-0.21.0rc0/replay/data/nn/parquet/constants}/__init__.py +0 -0
- {replay_rec-0.20.3rc0/replay/experimental/models/dt4rec → replay_rec-0.21.0rc0/replay/data/nn/parquet/impl}/__init__.py +0 -0
- {replay_rec-0.20.3rc0/replay/experimental/models/extensions/spark_custom_models → replay_rec-0.21.0rc0/replay/data/nn/parquet/info}/__init__.py +0 -0
- {replay_rec-0.20.3rc0/replay/experimental/scenarios/two_stages → replay_rec-0.21.0rc0/replay/data/nn/parquet/utils}/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/nn/utils.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/schema.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/data/spark_schema.py +0 -0
- {replay_rec-0.20.3rc0/replay/experimental → replay_rec-0.21.0rc0/replay/data}/utils/__init__.py +0 -0
- {replay_rec-0.20.3rc0/replay/models/extensions → replay_rec-0.21.0rc0/replay/data/utils/typing}/__init__.py +0 -0
- {replay_rec-0.20.3rc0/replay/models/extensions/ann → replay_rec-0.21.0rc0/replay/experimental}/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/base_metric.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/coverage.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/experiment.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/hitrate.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/map.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/mrr.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/ncis_precision.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/ndcg.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/precision.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/recall.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/rocauc.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/surprisal.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/metrics/unexpectedness.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/admm_slim.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/base_neighbour_rec.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/base_rec.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/base_torch_rec.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/cql.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/ddpg.py +0 -0
- {replay_rec-0.20.3rc0/replay/models/extensions/ann/entities → replay_rec-0.21.0rc0/replay/experimental/models/dt4rec}/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/dt4rec/dt4rec.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/dt4rec/gpt1.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/dt4rec/trainer.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/dt4rec/utils.py +0 -0
- {replay_rec-0.20.3rc0/replay/models/extensions/ann/index_builders → replay_rec-0.21.0rc0/replay/experimental/models/extensions/spark_custom_models}/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/hierarchical_recommender.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/implicit_wrap.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/lightfm_wrap.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/mult_vae.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/neural_ts.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/neuromf.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/scala_als.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/models/u_lin_ucb.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/nn/data/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/nn/data/schema_builder.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/preprocessing/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/preprocessing/data_preparator.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/preprocessing/padder.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/preprocessing/sequence_generator.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/scenarios/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/scenarios/obp_wrapper/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/scenarios/obp_wrapper/utils.py +0 -0
- {replay_rec-0.20.3rc0/replay/models/extensions/ann/index_inferers → replay_rec-0.21.0rc0/replay/experimental/scenarios/two_stages}/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/scenarios/two_stages/reranker.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -0
- {replay_rec-0.20.3rc0/replay/models/extensions/ann/index_stores → replay_rec-0.21.0rc0/replay/experimental/utils}/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/utils/logger.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/utils/model_handler.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/experimental/utils/session_handler.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/base_metric.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/categorical_diversity.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/coverage.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/descriptors.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/experiment.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/hitrate.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/map.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/mrr.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/ndcg.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/novelty.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/offline_metrics.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/precision.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/recall.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/rocauc.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/surprisal.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/metrics/unexpectedness.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/als.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/association_rules.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/base_neighbour_rec.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/base_rec.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/cat_pop_rec.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/cluster.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/common.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/ann_mixin.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/entities/base_hnsw_param.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/entities/hnswlib_param.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_builders/base_index_builder.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_inferers/base_inferer.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_inferers/utils.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_stores/base_index_store.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_stores/hdfs_index_store.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_stores/shared_disk_index_store.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_stores/spark_files_index_store.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/index_stores/utils.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/extensions/ann/utils.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/kl_ucb.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/knn.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/lin_ucb.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/loss/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/bert4rec/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/callbacks/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/compiled/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/postprocessors/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/nn/sequential/sasrec/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/optimization/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/optimization/optuna_mixin.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/optimization/optuna_objective.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/pop_rec.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/query_pop_rec.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/random_rec.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/slim.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/thompson_sampling.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/ucb.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/wilson.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/models/word2vec.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/preprocessing/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/preprocessing/converter.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/preprocessing/discretizer.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/preprocessing/history_based_fp.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/preprocessing/sessionizer.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/scenarios/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/scenarios/fallback.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/splitters/base_splitter.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/splitters/cold_user_random_splitter.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/splitters/k_folds.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/splitters/last_n_splitter.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/splitters/new_users_splitter.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/splitters/random_splitter.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/splitters/ratio_splitter.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/splitters/time_splitter.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/splitters/two_stage_splitter.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/utils/__init__.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/utils/dataframe_bucketizer.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/utils/distributions.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/utils/model_handler.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/utils/session_handler.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/utils/spark_utils.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/utils/time.py +0 -0
- {replay_rec-0.20.3rc0 → replay_rec-0.21.0rc0}/replay/utils/types.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: replay-rec
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.21.0rc0
|
|
4
4
|
Summary: RecSys Library
|
|
5
5
|
License-Expression: Apache-2.0
|
|
6
6
|
License-File: LICENSE
|
|
@@ -30,7 +30,7 @@ Requires-Dist: sb-obp (>=0.5.10,<0.6)
|
|
|
30
30
|
Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
|
|
31
31
|
Requires-Dist: scipy (>=1.8.1,<2.0.0)
|
|
32
32
|
Requires-Dist: setuptools
|
|
33
|
-
Requires-Dist: torch (>=1.8,<
|
|
33
|
+
Requires-Dist: torch (>=1.8,<3.0.0)
|
|
34
34
|
Requires-Dist: tqdm (>=4.67,<5)
|
|
35
35
|
Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
|
|
36
36
|
Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
|
|
@@ -231,7 +231,7 @@ pip install optuna
|
|
|
231
231
|
|
|
232
232
|
2) Model compilation via OpenVINO:
|
|
233
233
|
```bash
|
|
234
|
-
pip install openvino onnx
|
|
234
|
+
pip install openvino onnx onnxscript
|
|
235
235
|
```
|
|
236
236
|
|
|
237
237
|
3) Vector database and hierarchical search support:
|
|
@@ -1,11 +1,36 @@
|
|
|
1
1
|
[build-system]
|
|
2
2
|
requires = [
|
|
3
|
-
"poetry-core>=2.
|
|
3
|
+
"poetry-core>=2.2.1",
|
|
4
4
|
"poetry-dynamic-versioning>=1.0.0,<2.0.0",
|
|
5
5
|
"setuptools",
|
|
6
6
|
]
|
|
7
7
|
build-backend = "poetry_dynamic_versioning.backend"
|
|
8
8
|
|
|
9
|
+
[dependency-groups]
|
|
10
|
+
dev = [
|
|
11
|
+
"coverage-conditional-plugin (>=0.9, <1)",
|
|
12
|
+
"jupyter (>=1.0, <1.1)",
|
|
13
|
+
"jupyterlab (>=3.6, <4)",
|
|
14
|
+
"pyarrow-stubs",
|
|
15
|
+
"pytest (>=7.1.0)",
|
|
16
|
+
"pytest-mock (>3.15, <4.0)",
|
|
17
|
+
"pytest-cov (>=3.0)",
|
|
18
|
+
"statsmodels (>=0.14, <0.15)",
|
|
19
|
+
"black (>=23.3.0)",
|
|
20
|
+
"ruff (>=0.0.261)",
|
|
21
|
+
"hypothesis",
|
|
22
|
+
"toml-sort (>=0.23, <0.24)",
|
|
23
|
+
"sphinx (==5.3.0)",
|
|
24
|
+
"sphinx-rtd-theme (==1.2.2)",
|
|
25
|
+
"sphinx-autodoc-typehints (==1.23.0)",
|
|
26
|
+
"sphinx-enum-extend (==0.1.3)",
|
|
27
|
+
"myst-parser (==1.0.0)",
|
|
28
|
+
"ghp-import (==2.1.0)",
|
|
29
|
+
"docutils (==0.16)",
|
|
30
|
+
"data-science-types (==0.2.23)",
|
|
31
|
+
"filelock (>=3.14, <3.15)",
|
|
32
|
+
]
|
|
33
|
+
|
|
9
34
|
[project]
|
|
10
35
|
name = "replay-rec"
|
|
11
36
|
license = "Apache-2.0"
|
|
@@ -40,7 +65,7 @@ dependencies = [
|
|
|
40
65
|
"scikit-learn (>=1.6.1,<1.7.0)",
|
|
41
66
|
"pyarrow (<22.0)",
|
|
42
67
|
"tqdm (>=4.67,<5)",
|
|
43
|
-
"torch (>=1.8,<
|
|
68
|
+
"torch (>=1.8,<3.0.0)",
|
|
44
69
|
"lightning (>=2.0.2,<=2.4.0)",
|
|
45
70
|
"pytorch-optimizer (>=3.8.0,<4)",
|
|
46
71
|
"lightautoml (>=0.4.1,<0.5)",
|
|
@@ -52,7 +77,7 @@ dependencies = [
|
|
|
52
77
|
"psutil (<=7.0.0)",
|
|
53
78
|
]
|
|
54
79
|
dynamic = ["dependencies"]
|
|
55
|
-
version = "0.
|
|
80
|
+
version = "0.21.0.preview"
|
|
56
81
|
|
|
57
82
|
[project.urls]
|
|
58
83
|
homepage = "https://sb-ai-lab.github.io/RePlay/"
|
|
@@ -68,29 +93,14 @@ exclude = [
|
|
|
68
93
|
"replay/conftest.py",
|
|
69
94
|
]
|
|
70
95
|
|
|
71
|
-
[tool.poetry.
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
pytest = ">=7.1.0"
|
|
76
|
-
pytest-cov = ">=3.0.0"
|
|
77
|
-
statsmodels = "~0.14.0"
|
|
78
|
-
black = ">=23.3.0"
|
|
79
|
-
ruff = ">=0.0.261"
|
|
80
|
-
toml-sort = "^0.23.0"
|
|
81
|
-
sphinx = "5.3.0"
|
|
82
|
-
sphinx-rtd-theme = "1.2.2"
|
|
83
|
-
sphinx-autodoc-typehints = "1.23.0"
|
|
84
|
-
sphinx-enum-extend = "0.1.3"
|
|
85
|
-
myst-parser = "1.0.0"
|
|
86
|
-
ghp-import = "2.1.0"
|
|
87
|
-
docutils = "0.16"
|
|
88
|
-
data-science-types = "0.2.23"
|
|
89
|
-
filelock = "~3.14.0"
|
|
96
|
+
[[tool.poetry.source]]
|
|
97
|
+
name = "torch-cpu-mirror"
|
|
98
|
+
url = "https://download.pytorch.org/whl/cpu"
|
|
99
|
+
priority = "explicit"
|
|
90
100
|
|
|
91
101
|
[tool.poetry-dynamic-versioning]
|
|
92
102
|
enable = false
|
|
93
|
-
format-jinja = """0.
|
|
103
|
+
format-jinja = """0.21.0{{ env['PACKAGE_SUFFIX'] }}"""
|
|
94
104
|
vcs = "git"
|
|
95
105
|
|
|
96
106
|
[tool.ruff]
|
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
7
|
import json
|
|
8
|
+
import warnings
|
|
8
9
|
from collections.abc import Iterable, Sequence
|
|
9
10
|
from pathlib import Path
|
|
10
11
|
from typing import Callable, Optional, Union
|
|
@@ -45,6 +46,7 @@ class Dataset:
|
|
|
45
46
|
):
|
|
46
47
|
"""
|
|
47
48
|
:param feature_schema: mapping of columns names and feature infos.
|
|
49
|
+
All features not specified in the schema will be assumed numerical by default.
|
|
48
50
|
:param interactions: dataframe with interactions.
|
|
49
51
|
:param query_features: dataframe with query features,
|
|
50
52
|
defaults: ```None```.
|
|
@@ -498,6 +500,15 @@ class Dataset:
|
|
|
498
500
|
source=FeatureSource.QUERY_FEATURES,
|
|
499
501
|
feature_schema=updated_feature_schema,
|
|
500
502
|
)
|
|
503
|
+
|
|
504
|
+
if filled_features:
|
|
505
|
+
msg = (
|
|
506
|
+
"The following features are present in the dataset but have not been specified "
|
|
507
|
+
f"by the feature schema: {[(info.column, info.feature_source.value) for info in filled_features]}. "
|
|
508
|
+
"These features will be interpreted as NUMERICAL."
|
|
509
|
+
)
|
|
510
|
+
warnings.warn(msg, stacklevel=2)
|
|
511
|
+
|
|
501
512
|
return FeatureSchema(features_list=features_list + filled_features)
|
|
502
513
|
|
|
503
514
|
def _fill_unlabeled_features_sources(self, feature_schema: FeatureSchema) -> list[FeatureInfo]:
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from replay.utils import TORCH_AVAILABLE
|
|
2
2
|
|
|
3
3
|
if TORCH_AVAILABLE:
|
|
4
|
+
from .parquet import ParquetDataset, ParquetModule
|
|
4
5
|
from .schema import MutableTensorMap, TensorFeatureInfo, TensorFeatureSource, TensorMap, TensorSchema
|
|
5
6
|
from .sequence_tokenizer import SequenceTokenizer
|
|
6
7
|
from .sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset, SequentialDataset
|
|
@@ -18,6 +19,8 @@ if TORCH_AVAILABLE:
|
|
|
18
19
|
"DEFAULT_TRAIN_PADDING_VALUE",
|
|
19
20
|
"MutableTensorMap",
|
|
20
21
|
"PandasSequentialDataset",
|
|
22
|
+
"ParquetDataset",
|
|
23
|
+
"ParquetModule",
|
|
21
24
|
"PolarsSequentialDataset",
|
|
22
25
|
"SequenceTokenizer",
|
|
23
26
|
"SequentialDataset",
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implementation of the ``ParquetDataset`` and its internals.
|
|
3
|
+
|
|
4
|
+
``ParquetDataset`` is combination of PyTorch-compatible dataset and sampler which enables
|
|
5
|
+
training and inference of models on datasets of any arbitrary size by leveraging PyArrow
|
|
6
|
+
Datasets to perform batch-wise reading and processing of data from disk.
|
|
7
|
+
|
|
8
|
+
``ParquetDataset`` includes support for Pytorch's distributed training framework as well as
|
|
9
|
+
access to remotely stored data via PyArrow's filesystem configs.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from .info.replicas import DEFAULT_REPLICAS_INFO, ReplicasInfo, ReplicasInfoProtocol
|
|
13
|
+
from .parquet_dataset import ParquetDataset
|
|
14
|
+
from .parquet_module import ParquetModule
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"DEFAULT_REPLICAS_INFO",
|
|
18
|
+
"ParquetDataset",
|
|
19
|
+
"ParquetModule",
|
|
20
|
+
"ReplicasInfo",
|
|
21
|
+
"ReplicasInfoProtocol",
|
|
22
|
+
]
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from replay.data.nn.parquet.constants.batches import GeneralBatch, GeneralValue
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def dict_collate(batch: Sequence[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Simple collate function that converts a dict of values into a tensor dict."""
    merged: dict[str, torch.Tensor] = {}
    # Keys are taken from the first sample; all samples are expected to share them.
    for key in batch[0]:
        chunks = [sample[key] for sample in batch]
        merged[key] = torch.cat(chunks, dim=0)
    return merged
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def general_collate(batch: Sequence[GeneralBatch]) -> GeneralBatch:
    """General collate function that converts a nested dict of values into a tensor dict."""
    first = batch[0]

    # A single sample needs no merging at all.
    if len(batch) == 1:
        return first

    merged: GeneralBatch = {}
    for key, probe in first.items():
        per_key: Sequence[GeneralValue] = [sample[key] for sample in batch]
        if torch.is_tensor(probe):
            merged[key] = torch.cat(per_key, dim=0)
        else:
            # Non-tensor entries must be nested sub-batches; recurse into them.
            assert isinstance(probe, dict)
            merged[key] = general_collate(per_key)

    return merged
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from collections.abc import Sequence
from typing import Callable, Union

import torch
from typing_extensions import TypeAlias
|
|
5
|
+
|
|
6
|
+
GeneralValue: TypeAlias = Union[torch.Tensor, "GeneralBatch"]
|
|
7
|
+
GeneralBatch: TypeAlias = dict[str, GeneralValue]
|
|
8
|
+
GeneralCollateFn: TypeAlias = Callable[[GeneralBatch], GeneralBatch]
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from collections.abc import Iterator
|
|
3
|
+
from typing import Callable, Optional, Protocol, cast
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch.utils.data import IterableDataset
|
|
7
|
+
|
|
8
|
+
from replay.data.nn.parquet.constants.batches import GeneralBatch, GeneralCollateFn
|
|
9
|
+
from replay.data.nn.parquet.impl.masking import DEFAULT_COLLATE_FN
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_batch_size(batch: GeneralBatch, strict: bool = False) -> int:
    """
    Retrieves the size of the ``batch`` object.

    :param batch: Batch object (a possibly nested dict of tensors).
    :param strict: If ``True``, validates that every (nested) entry agrees on
        the batch size instead of trusting the first entry. Default: ``False``.

    :raises ValueError: If ``batch`` is empty, or if a size mismatch is found
        in the batch during a strict check.

    :return: Batch size (size of dimension 0 of the contained tensors).
    """
    batch_size: Optional[int] = None

    for key, value in batch.items():
        if torch.is_tensor(value):
            new_batch_size = value.size(0)
        else:
            # Non-tensor entries must be nested sub-batches; recurse.
            assert isinstance(value, dict)
            new_batch_size = get_batch_size(value, strict)

        if batch_size is None:
            batch_size = new_batch_size
            if not strict:
                # First entry is trusted in non-strict mode; skip the rest.
                break
        elif batch_size != new_batch_size:
            msg = f"Batch size mismatch {key}: {batch_size} != {new_batch_size}"
            raise ValueError(msg)

    # Raise an explicit error for empty batches instead of a bare assert,
    # which is stripped under ``python -O`` and only gave an AssertionError.
    if batch_size is None:
        msg = "Cannot determine the batch size of an empty batch"
        raise ValueError(msg)
    return batch_size
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def split_batches(batch: GeneralBatch, split: int) -> tuple[GeneralBatch, GeneralBatch]:
    """Split ``batch`` along dim 0 into its first ``split`` rows and the remainder."""
    head: GeneralBatch = {}
    tail: GeneralBatch = {}

    for name, value in batch.items():
        if torch.is_tensor(value):
            head[name], tail[name] = value[:split, ...], value[split:, ...]
        else:
            # Nested sub-batches are split recursively with the same cut point.
            head[name], tail[name] = split_batches(value, split)

    return (head, tail)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class DatasetProtocol(Protocol):
    """Structural interface for datasets accepted by ``FixedBatchSizeDataset``:
    an iterable of ``GeneralBatch`` objects that also exposes its ``batch_size``.
    """

    # Yields batches one by one.
    def __iter__(self) -> Iterator[GeneralBatch]: ...

    # Batch size of the yielded batches; read by ``FixedBatchSizeDataset``
    # when no explicit batch size is given.
    @property
    def batch_size(self) -> int: ...
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class FixedBatchSizeDataset(IterableDataset):
    """
    Wrapper for arbitrary datasets that fetches batches of fixed size.
    Concatenates batches from the wrapped dataset until it reaches the specified size.
    The last batch may be smaller than the specified size.
    """

    def __init__(
        self,
        dataset: DatasetProtocol,
        batch_size: Optional[int] = None,
        collate_fn: GeneralCollateFn = DEFAULT_COLLATE_FN,
        strict_checks: bool = False,
    ) -> None:
        """
        :param dataset: An iterable object that returns batches.
            Generally a subclass of ``torch.utils.data.IterableDataset``.
        :param batch_size: Desired batch size. If ``None``, will search for batch size in ``dataset.batch_size``.
            Default: ``None``.
        :param collate_fn: Collate function for merging batches. Default: value of ``DEFAULT_COLLATE_FN``.
        :param strict_checks: If ``True``, additional batch size checks will be performed.
            May affect performance. Default: ``False``.

        :raises ValueError: If an invalid batch size was provided or none could be determined.
        """
        super().__init__()

        self.dataset: DatasetProtocol = dataset

        # Fall back to the wrapped dataset's own batch size when none is given.
        # Validate with explicit raises (not ``assert``) so the documented
        # ValueError contract also holds under ``python -O``.
        if batch_size is None:
            batch_size = getattr(dataset, "batch_size", None)
        if not isinstance(batch_size, int):
            msg = f"Unable to determine a valid integer batch size. Got {batch_size=}"
            raise ValueError(msg)

        int_batch_size: int = batch_size

        if int_batch_size < 1:
            msg = f"Insufficient batch size. Got {int_batch_size=}"
            raise ValueError(msg)

        if int_batch_size < 2:
            warnings.warn(f"Low batch size. Got {int_batch_size=}. This may cause performance issues.", stacklevel=2)

        self.collate_fn: Callable = collate_fn
        self.batch_size: int = int_batch_size
        self.strict_checks: bool = strict_checks

    def get_batch_size(self, batch: GeneralBatch) -> int:
        """Size of ``batch``, cross-validated when ``strict_checks`` is enabled."""
        return get_batch_size(batch, strict=self.strict_checks)

    def __iter__(self) -> Iterator[GeneralBatch]:
        """Yield batches of exactly ``batch_size`` rows; the final one may be smaller."""
        iterator: Iterator[GeneralBatch] = iter(self.dataset)

        buffer: list[GeneralBatch] = []
        buffer_size: int = 0

        while True:
            # Accumulate upstream batches until a full batch can be emitted
            # or the upstream iterator is exhausted.
            while buffer_size < self.batch_size:
                try:
                    batch: GeneralBatch = next(iterator)
                    size: int = self.get_batch_size(batch)

                    buffer.append(batch)
                    buffer_size += size
                except StopIteration:
                    break

            if buffer_size == 0:
                break

            joined: GeneralBatch = self.collate_fn(buffer)
            assert buffer_size == self.get_batch_size(joined)

            if self.batch_size < buffer_size:
                # Emit a full batch; keep the overflow for the next round.
                left, right = split_batches(joined, self.batch_size)
                residue: int = buffer_size - self.batch_size
                assert residue == self.get_batch_size(right)

                buffer_size = residue
                buffer = [right]

                yield left
            else:
                # Exactly full, or the trailing partial batch at end of data.
                buffer_size = 0
                buffer = []

                yield joined

        assert buffer_size == 0
        assert len(buffer) == 0
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
from typing import Any, Union
|
|
2
|
+
|
|
3
|
+
import pyarrow as pa
|
|
4
|
+
import pyarrow.compute as pc
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from replay.data.nn.parquet.constants.device import DEFAULT_DEVICE
|
|
8
|
+
from replay.data.nn.parquet.constants.metadata import DEFAULT_PADDING
|
|
9
|
+
from replay.data.nn.parquet.metadata import (
|
|
10
|
+
Metadata,
|
|
11
|
+
get_1d_array_columns,
|
|
12
|
+
get_padding,
|
|
13
|
+
get_shape,
|
|
14
|
+
)
|
|
15
|
+
from replay.data.utils.typing.dtype import pyarrow_to_torch
|
|
16
|
+
|
|
17
|
+
from .column_protocol import OutputType
|
|
18
|
+
from .indexing import get_mask, get_offsets
|
|
19
|
+
from .utils import ensure_mutable
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Array1DColumn:
    """
    Representation of a 1D array column, containing a
    list of numbers of varying length in each of its rows.
    """

    def __init__(
        self,
        data: torch.Tensor,
        lengths: torch.LongTensor,
        shape: Union[int, list[int]],
        padding: Any = DEFAULT_PADDING,
    ) -> None:
        """
        :param data: A tensor containing column data (flattened row values).
        :param lengths: A tensor containing lengths of each individual row array.
        :param shape: An integer or list of integers representing the target array shapes.
        :param padding: Padding value to use to fill null values and match target shape.
            Default: value of ``DEFAULT_PADDING``

        :raises ValueError: If the shape provided is not one-dimensional.
        """
        if isinstance(shape, list) and len(shape) > 1:
            msg = f"Array1DColumn accepts a shape of size (1,) only. Got {shape=}"
            raise ValueError(msg)

        self.padding = padding
        self.data = data
        # Row boundaries inside the flattened ``data`` tensor.
        self.offsets = get_offsets(lengths)
        self.shape = shape[0] if isinstance(shape, list) else shape
        assert self.length == torch.numel(lengths)

    @property
    def length(self) -> int:
        # ``offsets`` carries one extra boundary element.
        return torch.numel(self.offsets) - 1

    def __len__(self) -> int:
        return self.length

    @property
    def device(self) -> torch.device:
        assert self.data.device == self.offsets.device
        return self.offsets.device

    @property
    def dtype(self) -> torch.dtype:
        return self.data.dtype

    def __getitem__(self, indices: torch.LongTensor) -> OutputType:
        """
        Gather the padded rows selected by ``indices``.

        :return: ``(mask, values)`` where ``mask`` flags real (non-padding) positions.
        """
        indices = indices.to(device=self.device)
        mask, output = get_mask(indices, self.offsets, self.shape)

        # Completely empty column: nothing to gather, every position is padding.
        # Bug fix: build the values tensor with the *column* dtype via
        # ``torch.full``. The previous ``torch.ones(..., dtype=torch.bool) * padding``
        # produced a promotion-dependent dtype instead of ``self.dtype``
        # (the non-empty path asserts ``dtype == self.dtype``).
        if self.data.numel() == 0:
            empty_mask = torch.zeros((indices.size(0), self.shape), dtype=torch.bool, device=self.device)
            padded = torch.full((indices.size(0), self.shape), self.padding, dtype=self.dtype, device=self.device)
            return empty_mask, padded

        unmasked_values = torch.take(self.data, output)
        masked_values = torch.where(mask, unmasked_values, self.padding)
        assert masked_values.device == self.device
        assert masked_values.dtype == self.dtype
        return (mask, masked_values)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def to_torch(array: pa.Array, device: torch.device = DEFAULT_DEVICE) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Converts a PyArrow array into a PyTorch tensor.

    :param array: Original PyArrow array.
    :param device: Target device to send the resulting tensor to. Default: value of ``DEFAULT_DEVICE``.

    :return: A PyTorch tensor obtained from original array.
    """
    values = pc.list_flatten(array)
    row_lengths = pc.list_value_length(array).cast(pa.int64())

    # ``ensure_mutable`` copies the NumPy buffers so the tensors own
    # writable memory instead of aliasing immutable Arrow data.
    values_tensor = torch.asarray(
        ensure_mutable(values.to_numpy()),
        device=device,
        dtype=pyarrow_to_torch(values.type),
    )

    lengths_tensor = torch.asarray(
        ensure_mutable(row_lengths.to_numpy()),
        device=device,
        dtype=torch.int64,
    )
    return (lengths_tensor, values_tensor)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def to_array_1d_columns(
    data: pa.RecordBatch,
    metadata: Metadata,
    device: torch.device = DEFAULT_DEVICE,
) -> dict[str, Array1DColumn]:
    """
    Converts a PyArrow batch of data to a set of ``Array1DColumn``s.
    This function filters only those columns matching its format from the full batch.

    :param data: A PyArrow batch of column data.
    :param metadata: Metadata containing information about columns' formats.
    :param device: Target device to send column tensors to. Default: value of ``DEFAULT_DEVICE``

    :return: A dict of tensors containing dataset's numeric columns.
    """
    columns: dict[str, Array1DColumn] = {}

    # Only the columns that metadata marks as 1D arrays are converted.
    for name in get_1d_array_columns(metadata):
        lengths, values = to_torch(data.column(name), device=device)
        columns[name] = Array1DColumn(
            data=values,
            lengths=lengths,
            padding=get_padding(metadata, name),
            shape=get_shape(metadata, name),
        )
    return columns
|