nshtrainer 1.3.1__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/model/base.py CHANGED
@@ -85,12 +85,25 @@ def default_split_batched_predictions(
85
85
  """
86
86
  import torch.utils._pytree as tree
87
87
 
88
- for sample_idx, batch_idx in enumerate(batch_indices):
88
+ for i, global_idx in enumerate(batch_indices):
89
+
90
+ def _verify_and_index(x: torch.Tensor):
91
+ # Make sure dim 0 length is equal to the batch size,
92
+ # otherwise we can't index it and should prompt
93
+ # the user to implement a splitter
94
+ if x.shape[0] != len(batch_indices):
95
+ raise ValueError(
96
+ f"Batch size {x.shape[0]} does not match the number of batch indices {len(batch_indices)}. "
97
+ "Please implement a custom `split_batched_predictions` method in your LightningModuleBase class."
98
+ )
99
+
100
+ return x[i]
101
+
89
102
  # Create a dictionary for each sample
90
103
  yield IndividualSample(
91
- index=batch_idx,
92
- batch=tree.tree_map(lambda x: x[sample_idx], batch),
93
- prediction=tree.tree_map(lambda x: x[sample_idx], prediction),
104
+ index=global_idx,
105
+ batch=tree.tree_map(_verify_and_index, batch),
106
+ prediction=tree.tree_map(_verify_and_index, prediction),
94
107
  )
95
108
 
96
109
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.3.1
3
+ Version: 1.3.2
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -118,7 +118,7 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=irPyDjfUX843ze4bJM9sW8WSe
118
118
  nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
119
119
  nshtrainer/metrics/_config.py,sha256=ox_ScK6V0J9nzIMhEB0qpToNKpt83VVgOVSRFCV-wBc,595
120
120
  nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
121
- nshtrainer/model/base.py,sha256=Pv3M3QStWQp-DnfGFsLPAmp87HHrX1NrkAa4JcyBoDk,10255
121
+ nshtrainer/model/base.py,sha256=PvTmupfGahEZME0BWqbeErDPP1VOm2Nm9JxJkO8afcc,10815
122
122
  nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
123
123
  nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
124
124
  nshtrainer/model/mixins/logger.py,sha256=7u9fQig-SVFA9RFIB4U0gqJAzruh49mgmXXvZ6VkDUk,11694
@@ -160,6 +160,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
160
160
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
161
161
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
162
162
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
163
- nshtrainer-1.3.1.dist-info/METADATA,sha256=RCFzQ6YlNZmaYUMcLR4RMotPI3X3QXFwI6MWyN5nkjE,960
164
- nshtrainer-1.3.1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
165
- nshtrainer-1.3.1.dist-info/RECORD,,
163
+ nshtrainer-1.3.2.dist-info/METADATA,sha256=XQdG9IP0N83areh70D1kM_rneikv2JQnO2VS34_MXRM,960
164
+ nshtrainer-1.3.2.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
165
+ nshtrainer-1.3.2.dist-info/RECORD,,