code-loader 1.0.183.dev0__py3-none-any.whl → 1.0.184.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
code_loader/leaploader.py CHANGED
@@ -44,7 +44,7 @@ class LeapLoader(LeapLoaderBase):
44
44
  super().__init__(code_path, code_entry_name)
45
45
 
46
46
  self._preprocess_result_cached = None
47
- self._synthetic_responses: Dict[str, PreprocessResponse] = {}
47
+ self._synthetic_lookup: Dict[str, Tuple[PreprocessResponse, Any]] = {}
48
48
 
49
49
  try:
50
50
  from code_loader.mixpanel_tracker import track_code_loader_loaded
@@ -191,22 +191,11 @@ class LeapLoader(LeapLoaderBase):
191
191
  def get_sample(self, state: DataStateEnum, sample_id: Union[int, str], instance_id: int = None) -> DatasetSample:
192
192
  self.exec_script()
193
193
 
194
- if isinstance(sample_id, str) and sample_id.startswith("synthetic_"):
195
- last_underscore = sample_id.rfind("_")
196
- if last_underscore != -1:
197
- prefix = sample_id[:last_underscore]
198
- local_idx_str = sample_id[last_underscore + 1:]
199
- sim_preprocess = self._synthetic_responses.get(prefix)
200
- if sim_preprocess is not None:
201
- try:
202
- local_idx = int(local_idx_str)
203
- original_sample_id = sim_preprocess.sample_ids[local_idx]
204
- except (ValueError, IndexError):
205
- pass
206
- else:
207
- return self._get_sample_from_preprocess(
208
- sim_preprocess, original_sample_id, synthetic_index=sample_id
209
- )
194
+ if isinstance(sample_id, str) and sample_id in self._synthetic_lookup:
195
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
196
+ return self._get_sample_from_preprocess(
197
+ sim_preprocess, original_local_id, synthetic_index=sample_id
198
+ )
210
199
 
211
200
  preprocess_result = self._preprocess_result()
212
201
  if state == DataStateEnum.unlabeled and sample_id not in preprocess_result[state].sample_ids:
@@ -408,8 +397,8 @@ class LeapLoader(LeapLoaderBase):
408
397
  result_payloads.append(test_result)
409
398
  return result_payloads
410
399
 
411
- def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_id_prefix=None):
412
- # type: (str, Optional[Dict[str, Any]], int, int, Optional[str]) -> Dict[str, Any]
400
+ def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_ids=None):
401
+ # type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]]) -> Dict[str, Any]
413
402
  self.exec_script()
414
403
  sim = next(
415
404
  (s for s in global_leap_binder.setup_container.simulations if s.name == sim_name),
@@ -433,13 +422,45 @@ class LeapLoader(LeapLoaderBase):
433
422
  for handler in global_leap_binder.setup_container.inputs:
434
423
  per_encoder[handler.name].append(handler.function(sample_id, sim_preprocess))
435
424
  encoded = {name: np.stack(arrays) for name, arrays in per_encoder.items()}
436
- if sample_id_prefix is not None:
437
- returned_sample_ids = ["{}_{}".format(sample_id_prefix, i) for i in range(len(original_sample_ids))]
438
- self._synthetic_responses[sample_id_prefix] = sim_preprocess
425
+ if sample_ids is not None:
426
+ if len(sample_ids) != len(original_sample_ids):
427
+ raise ValueError(
428
+ "sample_ids length ({}) does not match simulation output length ({})".format(
429
+ len(sample_ids), len(original_sample_ids)
430
+ )
431
+ )
432
+ for sid in sample_ids:
433
+ if not isinstance(sid, str):
434
+ raise TypeError(
435
+ "All sample_ids must be of type str. Got: {}".format(type(sid))
436
+ )
437
+ for synth_id, original_local_id in zip(sample_ids, original_sample_ids):
438
+ self._synthetic_lookup[synth_id] = (sim_preprocess, original_local_id)
439
+ self._extend_additional_preprocess(list(sample_ids))
440
+ returned_sample_ids = list(sample_ids)
439
441
  else:
440
442
  returned_sample_ids = original_sample_ids
441
443
  return {"encoded": encoded, "sample_ids": returned_sample_ids}
442
444
 
445
+ def _extend_additional_preprocess(self, new_sample_ids: List[str]) -> None:
446
+ if self._preprocess_result_cached is None:
447
+ self._preprocess_result()
448
+ additional = self._preprocess_result_cached.get(DataStateEnum.additional)
449
+ if additional is None:
450
+ placeholder = PreprocessResponse(
451
+ sample_ids=list(new_sample_ids),
452
+ data={},
453
+ state=DataStateType.additional,
454
+ sample_id_type=str,
455
+ tl_generated=True,
456
+ )
457
+ self._preprocess_result_cached[DataStateEnum.additional] = placeholder
458
+ global_leap_binder.setup_container.preprocess.data_length[DataStateType.additional] = placeholder.length
459
+ else:
460
+ additional.sample_ids = list(additional.sample_ids) + list(new_sample_ids)
461
+ additional.length = len(additional.sample_ids)
462
+ global_leap_binder.setup_container.preprocess.data_length[DataStateType.additional] = additional.length
463
+
443
464
  @staticmethod
444
465
  def _get_all_dataset_base_handlers() -> List[Union[DatasetBaseHandler, MetadataHandler]]:
445
466
  all_dataset_base_handlers: List[Union[DatasetBaseHandler, MetadataHandler]] = []
@@ -653,7 +674,13 @@ class LeapLoader(LeapLoaderBase):
653
674
  state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, npt.NDArray[np.float32]]:
654
675
  result_agg = {}
655
676
  preprocess_result = self._preprocess_result()
656
- preprocess_state = preprocess_result[state]
677
+ if (state == DataStateEnum.additional and isinstance(sample_id, str)
678
+ and sample_id in self._synthetic_lookup):
679
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
680
+ preprocess_state = sim_preprocess
681
+ sample_id = original_local_id
682
+ else:
683
+ preprocess_state = preprocess_result[state]
657
684
  for handler in handlers:
658
685
  handler_result = handler.function(sample_id, preprocess_state)
659
686
  handler_name = handler.name
@@ -667,7 +694,13 @@ class LeapLoader(LeapLoaderBase):
667
694
  if instance_id is None:
668
695
  return None
669
696
  preprocess_result = self._preprocess_result()
670
- preprocess_state = preprocess_result[state]
697
+ if (state == DataStateEnum.additional and isinstance(sample_id, str)
698
+ and sample_id in self._synthetic_lookup):
699
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
700
+ preprocess_state = sim_preprocess
701
+ sample_id = original_local_id
702
+ else:
703
+ preprocess_state = preprocess_result[state]
671
704
  result_agg = {}
672
705
  for handler in global_leap_binder.setup_container.instance_masks:
673
706
  handler_result = handler.function(sample_id, preprocess_state, instance_id)
@@ -727,7 +760,13 @@ class LeapLoader(LeapLoaderBase):
727
760
  result_agg = {}
728
761
  is_none = {}
729
762
  preprocess_result = self._preprocess_result()
730
- preprocess_state = preprocess_result[state]
763
+ if (state == DataStateEnum.additional and isinstance(sample_id, str)
764
+ and sample_id in self._synthetic_lookup):
765
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
766
+ preprocess_state = sim_preprocess
767
+ sample_id = original_local_id
768
+ else:
769
+ preprocess_state = preprocess_result[state]
731
770
  for handler in global_leap_binder.setup_container.metadata:
732
771
  if requested_metadata_names:
733
772
  if not is_metadata_name_starts_with_handler_name(handler):
@@ -768,17 +807,10 @@ class LeapLoader(LeapLoaderBase):
768
807
  return global_leap_binder.setup_container.custom_latent_space is not None
769
808
 
770
809
  def get_instances_data(self, state: DataStateEnum) -> Tuple[Dict[str, List[str]], Dict[str, str]]:
771
- """
772
- This Method get the data state and returns two dictionaries that holds the mapping of the sample ids to their
773
- instances and the other way around and the sample ids array.
774
- Args:
775
- state: DataStateEnum state
776
- Returns:
777
- sample_ids_to_instance_mappings: sample id to instance mappings
778
- instance_to_sample_ids_mappings: instance to sample ids mappings
779
- sample_ids: sample ids array
780
- """
781
810
  preprocess_result = self._preprocess_result()
782
- preprocess_state = preprocess_result[state]
783
- return preprocess_state.sample_ids_to_instance_mappings, preprocess_state.instance_to_sample_ids_mappings
811
+ preprocess_state = preprocess_result.get(state)
812
+ if preprocess_state is None:
813
+ return {}, {}
814
+ return (preprocess_state.sample_ids_to_instance_mappings or {},
815
+ preprocess_state.instance_to_sample_ids_mappings or {})
784
816
 
@@ -154,8 +154,8 @@ class LeapLoaderBase:
154
154
  pass
155
155
 
156
156
  @abstractmethod
157
- def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_id_prefix=None):
158
- # type: (str, Optional[Dict[str, Any]], int, int, Optional[str]) -> Dict[str, Any]
157
+ def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_ids=None):
158
+ # type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]]) -> Dict[str, Any]
159
159
  pass
160
160
 
161
161
  def is_custom_latent_space(self) -> bool:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.183.dev0
3
+ Version: 1.0.184.dev1
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -23,8 +23,8 @@ code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaS
23
23
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
24
24
  code_loader/inner_leap_binder/leapbinder.py,sha256=XLYYcV50qjMvoC1S6WW0tLBch_0g5gl1UyHiVSWYbvg,40491
25
25
  code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=YlK51b5Ryo396f7tOA7Ole3vYCLs3f5ZLm2qDQ9K1NE,105781
26
- code_loader/leaploader.py,sha256=mPTi72odXIhPfYpiwGsuDklTBKxnQA9eFhGjsmuGzAU,39942
27
- code_loader/leaploaderbase.py,sha256=HRmh62g7kVtiD4a8eSB3dTCSnCGz-3Uar0wVItBdhhY,6530
26
+ code_loader/leaploader.py,sha256=9QAG3eJHdV8HleBybIwUKqfkhIjnoGIPL9aLb4nNSEc,41626
27
+ code_loader/leaploaderbase.py,sha256=l36qDA00GhZEG5NLKpEtAXgWJA-UQQIhNFGxywK7mUA,6530
28
28
  code_loader/mixpanel_tracker.py,sha256=rNwRmFifNbdUoqLQvvhhgpKczWpWiEmd8MfyJe27sxw,9131
29
29
  code_loader/plot_functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  code_loader/plot_functions/plot_functions.py,sha256=6Q7VWGxetL2W0EK2QeCdObVATvBuHs3YBA09H4uoIk0,14996
@@ -32,7 +32,7 @@ code_loader/plot_functions/visualize.py,sha256=gsBAYYkwMh7jIpJeDMPS8G4CW-pxwx6Lz
32
32
  code_loader/utils.py,sha256=YecipkdTA-VcE9F0RQcY9cFnY8P3AksPnHM2Db7xUSk,3972
33
33
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
34
  code_loader/visualizers/default_visualizers.py,sha256=onRnLE_TXfgLN4o52hQIOOhUcFexGlqJ3xSpQDVLuZM,2604
35
- code_loader-1.0.183.dev0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
36
- code_loader-1.0.183.dev0.dist-info/METADATA,sha256=iN7D5iiqQv7iPN0hABw1u6ALg4FKnFfJTZiYrgp1km0,1095
37
- code_loader-1.0.183.dev0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
38
- code_loader-1.0.183.dev0.dist-info/RECORD,,
35
+ code_loader-1.0.184.dev1.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
36
+ code_loader-1.0.184.dev1.dist-info/METADATA,sha256=I5WOCfghR7fcWie_Bhrjrbu420bwjKBKLjaEFFpPOm0,1095
37
+ code_loader-1.0.184.dev1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
38
+ code_loader-1.0.184.dev1.dist-info/RECORD,,