nshtrainer 1.3.1__py3-none-any.whl → 1.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/model/base.py CHANGED
@@ -85,12 +85,25 @@ def default_split_batched_predictions(
85
85
  """
86
86
  import torch.utils._pytree as tree
87
87
 
88
- for sample_idx, batch_idx in enumerate(batch_indices):
88
+ for i, global_idx in enumerate(batch_indices):
89
+
90
+ def _verify_and_index(x: torch.Tensor):
91
+ # Make sure dim 0 length is equal to the batch size,
92
+ # otherwise we can't index it and should prompt
93
+ # the user to implement a splitter
94
+ if x.shape[0] != len(batch_indices):
95
+ raise ValueError(
96
+ f"Batch size {x.shape[0]} does not match the number of batch indices {len(batch_indices)}. "
97
+ "Please implement a custom `split_batched_predictions` method in your LightningModuleBase class."
98
+ )
99
+
100
+ return x[i]
101
+
89
102
  # Create a dictionary for each sample
90
103
  yield IndividualSample(
91
- index=batch_idx,
92
- batch=tree.tree_map(lambda x: x[sample_idx], batch),
93
- prediction=tree.tree_map(lambda x: x[sample_idx], prediction),
104
+ index=global_idx,
105
+ batch=tree.tree_map(_verify_and_index, batch),
106
+ prediction=tree.tree_map(_verify_and_index, prediction),
94
107
  )
95
108
 
96
109
 
@@ -374,6 +374,10 @@ class TrainerConfig(C.Config):
374
374
  """Tags for the run."""
375
375
  notes: list[str] = []
376
376
  """Human readable notes for the run."""
377
+ meta: dict[str, Any] = {}
378
+ """Metadata information for the run. This is a dictionary that can be used to store any additional information
379
+ about the run. It is not used by nshtrainer, but can be useful for logging or tracking purposes.
380
+ """
377
381
 
378
382
  @property
379
383
  def full_name(self):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.3.1
3
+ Version: 1.3.3
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -118,7 +118,7 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=irPyDjfUX843ze4bJM9sW8WSe
118
118
  nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
119
119
  nshtrainer/metrics/_config.py,sha256=ox_ScK6V0J9nzIMhEB0qpToNKpt83VVgOVSRFCV-wBc,595
120
120
  nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
121
- nshtrainer/model/base.py,sha256=Pv3M3QStWQp-DnfGFsLPAmp87HHrX1NrkAa4JcyBoDk,10255
121
+ nshtrainer/model/base.py,sha256=PvTmupfGahEZME0BWqbeErDPP1VOm2Nm9JxJkO8afcc,10815
122
122
  nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
123
123
  nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
124
124
  nshtrainer/model/mixins/logger.py,sha256=7u9fQig-SVFA9RFIB4U0gqJAzruh49mgmXXvZ6VkDUk,11694
@@ -135,7 +135,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
135
135
  nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
136
136
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
137
137
  nshtrainer/trainer/__init__.py,sha256=jRaHdaFK8wxNrN1bleT9cf29iZahL_-XkWo5TWz2CmA,550
138
- nshtrainer/trainer/_config.py,sha256=Lt9tuzxgVzVnyEFz61xbaPudfsXbKYUphOg-qMDHO8g,33203
138
+ nshtrainer/trainer/_config.py,sha256=SohR7uxANnP3xrrcW_mAjk6TuDamsW5Qdk3dlnPinDw,33457
139
139
  nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
140
140
  nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
141
141
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
@@ -160,6 +160,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
160
160
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
161
161
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
162
162
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
163
- nshtrainer-1.3.1.dist-info/METADATA,sha256=RCFzQ6YlNZmaYUMcLR4RMotPI3X3QXFwI6MWyN5nkjE,960
164
- nshtrainer-1.3.1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
165
- nshtrainer-1.3.1.dist-info/RECORD,,
163
+ nshtrainer-1.3.3.dist-info/METADATA,sha256=K_xd3BrF1Yz7gGbNQgywkjysCFuwXi3GCBoQ5EaFVKY,960
164
+ nshtrainer-1.3.3.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
165
+ nshtrainer-1.3.3.dist-info/RECORD,,