code-loader 1.0.184.dev2__py3-none-any.whl → 1.0.184.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
code_loader/leaploader.py CHANGED
@@ -7,7 +7,7 @@ import sys
7
7
  from contextlib import redirect_stdout
8
8
  from functools import lru_cache
9
9
  from pathlib import Path
10
- from typing import Dict, List, Iterable, Union, Any, Type, Optional, Callable, Tuple
10
+ from typing import Dict, List, Iterable, Set, Union, Any, Type, Optional, Callable, Tuple
11
11
 
12
12
  import numpy as np
13
13
  import numpy.typing as npt
@@ -45,6 +45,7 @@ class LeapLoader(LeapLoaderBase):
45
45
 
46
46
  self._preprocess_result_cached = None
47
47
  self._synthetic_lookup: Dict[str, Tuple[PreprocessResponse, Any]] = {}
48
+ self._synthetic_populator: Optional[Callable[[str], None]] = None
48
49
 
49
50
  try:
50
51
  from code_loader.mixpanel_tracker import track_code_loader_loaded
@@ -188,11 +189,58 @@ class LeapLoader(LeapLoaderBase):
188
189
  for prediction_type in setup.prediction_types
189
190
  }
190
191
 
192
+ def set_synthetic_populator(self, populator: Optional[Callable[[str], None]]) -> None:
193
+ # Hook called by `_resolve_synthetic` when a synthetic sample_id is requested
194
+ # but not yet in `_synthetic_lookup`. The populator should call `run_simulation`
195
+ # for the recipe that produced sample_id, or raise if the recipe can't be
196
+ # resolved. After the call, the entire batch the recipe produced is in
197
+ # `_synthetic_lookup`, so subsequent lookups for sibling sample_ids skip the hook.
198
+ self._synthetic_populator = populator
199
+
200
+ def _user_additional_ids_set(self) -> Set[str]:
201
+ # Membership-test helper for distinguishing user-defined additional samples
202
+ # from synthetic ones. In consumer pods, the populator passes
203
+ # extend_preprocess=False to run_simulation, so additional.sample_ids stays
204
+ # uncontaminated and direct membership check is reliable.
205
+ additional = self._preprocess_result().get(DataStateEnum.additional)
206
+ if additional is None:
207
+ return set()
208
+ return set(additional.sample_ids)
209
+
210
+ def _resolve_synthetic(self, sample_id: Union[int, str],
211
+ state: Optional[DataStateEnum] = None
212
+ ) -> Optional[Tuple[PreprocessResponse, Any]]:
213
+ # Resolution flow for state==additional:
214
+ # 1. already-replayed synthetic ID (member of _synthetic_lookup) → cached synthetic flow
215
+ # 2. user-defined ID (member of preprocess[additional].sample_ids) → return None
216
+ # so caller falls through to normal preprocess flow (no recipe lookup built)
217
+ # 3. populator: builds recipe index lazily on first call, replays the missing
218
+ # recipe, populates _synthetic_lookup. The populator is expected to raise on unresolvable IDs; if no populator is installed, returns None.
219
+ #
220
+ # Order rationale: the _synthetic_lookup check sits before the user-defined-ID check
221
+ # to be robust against `_extend_additional_preprocess` mutation (producer side
222
+ # calls run_simulation with extend_preprocess=True, which appends synthetic IDs
223
+ # into preprocess[additional]). Once an ID is in _synthetic_lookup, that wins
224
+ # regardless of contamination in the preprocess response.
225
+ if not isinstance(sample_id, str):
226
+ return None
227
+ if state != DataStateEnum.additional:
228
+ return None
229
+ if sample_id in self._synthetic_lookup:
230
+ return self._synthetic_lookup[sample_id]
231
+ if sample_id in self._user_additional_ids_set():
232
+ return None
233
+ if self._synthetic_populator is None:
234
+ return None
235
+ self._synthetic_populator(sample_id)
236
+ return self._synthetic_lookup.get(sample_id)
237
+
191
238
  def get_sample(self, state: DataStateEnum, sample_id: Union[int, str], instance_id: int = None) -> DatasetSample:
192
239
  self.exec_script()
193
240
 
194
- if isinstance(sample_id, str) and sample_id in self._synthetic_lookup:
195
- sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
241
+ resolved = self._resolve_synthetic(sample_id, state=state)
242
+ if resolved is not None:
243
+ sim_preprocess, original_local_id = resolved
196
244
  return self._get_sample_from_preprocess(
197
245
  sim_preprocess, original_local_id, synthetic_index=sample_id
198
246
  )
@@ -397,8 +445,13 @@ class LeapLoader(LeapLoaderBase):
397
445
  result_payloads.append(test_result)
398
446
  return result_payloads
399
447
 
400
- def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_ids=None):
401
- # type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]]) -> Dict[str, Any]
448
+ def run_simulation(self, sim_name, params=None, n_samples=1, seed=0,
449
+ sample_ids=None, extend_preprocess=True):
450
+ # type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]], bool) -> Dict[str, Any]
451
+ # extend_preprocess=True (default): also extend preprocess_result[additional] with the
452
+ # new synthetic sample_ids (producer-side behavior — synthetic worker needs this for
453
+ # enumeration / data_length tracking). Pass False from consumer-side populators that
454
+ # only need _synthetic_lookup filled, so user's additional response stays uncontaminated.
402
455
  self.exec_script()
403
456
  sim = next(
404
457
  (s for s in global_leap_binder.setup_container.simulations if s.name == sim_name),
@@ -436,7 +489,8 @@ class LeapLoader(LeapLoaderBase):
436
489
  )
437
490
  for synth_id, original_local_id in zip(sample_ids, original_sample_ids):
438
491
  self._synthetic_lookup[synth_id] = (sim_preprocess, original_local_id)
439
- self._extend_additional_preprocess(list(sample_ids))
492
+ if extend_preprocess:
493
+ self._extend_additional_preprocess(list(sample_ids))
440
494
  returned_sample_ids = list(sample_ids)
441
495
  else:
442
496
  returned_sample_ids = original_sample_ids
@@ -674,9 +728,9 @@ class LeapLoader(LeapLoaderBase):
674
728
  state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, npt.NDArray[np.float32]]:
675
729
  result_agg = {}
676
730
  preprocess_result = self._preprocess_result()
677
- if (state == DataStateEnum.additional and isinstance(sample_id, str)
678
- and sample_id in self._synthetic_lookup):
679
- sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
731
+ resolved = self._resolve_synthetic(sample_id, state=state)
732
+ if resolved is not None:
733
+ sim_preprocess, original_local_id = resolved
680
734
  preprocess_state = sim_preprocess
681
735
  sample_id = original_local_id
682
736
  else:
@@ -694,9 +748,9 @@ class LeapLoader(LeapLoaderBase):
694
748
  if instance_id is None:
695
749
  return None
696
750
  preprocess_result = self._preprocess_result()
697
- if (state == DataStateEnum.additional and isinstance(sample_id, str)
698
- and sample_id in self._synthetic_lookup):
699
- sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
751
+ resolved = self._resolve_synthetic(sample_id, state=state)
752
+ if resolved is not None:
753
+ sim_preprocess, original_local_id = resolved
700
754
  preprocess_state = sim_preprocess
701
755
  sample_id = original_local_id
702
756
  else:
@@ -760,9 +814,9 @@ class LeapLoader(LeapLoaderBase):
760
814
  result_agg = {}
761
815
  is_none = {}
762
816
  preprocess_result = self._preprocess_result()
763
- if (state == DataStateEnum.additional and isinstance(sample_id, str)
764
- and sample_id in self._synthetic_lookup):
765
- sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
817
+ resolved = self._resolve_synthetic(sample_id, state=state)
818
+ if resolved is not None:
819
+ sim_preprocess, original_local_id = resolved
766
820
  preprocess_state = sim_preprocess
767
821
  sample_id = original_local_id
768
822
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.184.dev2
3
+ Version: 1.0.184.dev4
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -23,7 +23,7 @@ code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaS
23
23
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
24
24
  code_loader/inner_leap_binder/leapbinder.py,sha256=XLYYcV50qjMvoC1S6WW0tLBch_0g5gl1UyHiVSWYbvg,40491
25
25
  code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=YlK51b5Ryo396f7tOA7Ole3vYCLs3f5ZLm2qDQ9K1NE,105781
26
- code_loader/leaploader.py,sha256=70enWfzlCfp4OkWQYMIXJ__V7aoJQfgiKSHrikRnpTc,41600
26
+ code_loader/leaploader.py,sha256=K9Q5kiCZ_A2GkS5qOEantAMvWGaz3bmO1X8buXDSpgg,44682
27
27
  code_loader/leaploaderbase.py,sha256=l36qDA00GhZEG5NLKpEtAXgWJA-UQQIhNFGxywK7mUA,6530
28
28
  code_loader/mixpanel_tracker.py,sha256=rNwRmFifNbdUoqLQvvhhgpKczWpWiEmd8MfyJe27sxw,9131
29
29
  code_loader/plot_functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -32,7 +32,7 @@ code_loader/plot_functions/visualize.py,sha256=gsBAYYkwMh7jIpJeDMPS8G4CW-pxwx6Lz
32
32
  code_loader/utils.py,sha256=YecipkdTA-VcE9F0RQcY9cFnY8P3AksPnHM2Db7xUSk,3972
33
33
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
34
  code_loader/visualizers/default_visualizers.py,sha256=onRnLE_TXfgLN4o52hQIOOhUcFexGlqJ3xSpQDVLuZM,2604
35
- code_loader-1.0.184.dev2.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
36
- code_loader-1.0.184.dev2.dist-info/METADATA,sha256=ye_o6x8GQGs0cQsHnOIh27A8TRnCRq64p_6XwslFsv4,1095
37
- code_loader-1.0.184.dev2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
38
- code_loader-1.0.184.dev2.dist-info/RECORD,,
35
+ code_loader-1.0.184.dev4.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
36
+ code_loader-1.0.184.dev4.dist-info/METADATA,sha256=t_BqhlREnawOv6Y-bEqGgefwtLGwg6hmVvdYIaKnzLI,1095
37
+ code_loader-1.0.184.dev4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
38
+ code_loader-1.0.184.dev4.dist-info/RECORD,,