code-loader 1.0.183.dev0__tar.gz → 1.0.184.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/PKG-INFO +1 -1
  2. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/leaploader.py +62 -25
  3. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/leaploaderbase.py +2 -2
  4. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/pyproject.toml +1 -1
  5. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/LICENSE +0 -0
  6. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/README.md +0 -0
  7. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/__init__.py +0 -0
  8. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/__init__.py +0 -0
  9. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/datasetclasses.py +0 -0
  10. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/enums.py +0 -0
  11. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/exceptions.py +0 -0
  12. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/mapping.py +0 -0
  13. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/responsedataclasses.py +0 -0
  14. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/sim_config.py +0 -0
  15. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/visualizer_classes.py +0 -0
  16. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/default_losses.py +0 -0
  17. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/default_metrics.py +0 -0
  18. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/__init__.py +0 -0
  19. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/api.py +0 -0
  20. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/cli_config_utils.py +0 -0
  21. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/client.py +0 -0
  22. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/epoch.py +0 -0
  23. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/experiment.py +0 -0
  24. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/experiment_context.py +0 -0
  25. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/types.py +0 -0
  26. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/utils.py +0 -0
  27. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/workingspace_config_utils.py +0 -0
  28. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/inner_leap_binder/__init__.py +0 -0
  29. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/inner_leap_binder/leapbinder.py +0 -0
  30. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/inner_leap_binder/leapbinder_decorators.py +0 -0
  31. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/mixpanel_tracker.py +0 -0
  32. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/plot_functions/__init__.py +0 -0
  33. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/plot_functions/plot_functions.py +0 -0
  34. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/plot_functions/visualize.py +0 -0
  35. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/utils.py +0 -0
  36. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/visualizers/__init__.py +0 -0
  37. {code_loader-1.0.183.dev0 → code_loader-1.0.184.dev0}/code_loader/visualizers/default_visualizers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.183.dev0
3
+ Version: 1.0.184.dev0
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -44,7 +44,7 @@ class LeapLoader(LeapLoaderBase):
44
44
  super().__init__(code_path, code_entry_name)
45
45
 
46
46
  self._preprocess_result_cached = None
47
- self._synthetic_responses: Dict[str, PreprocessResponse] = {}
47
+ self._synthetic_lookup: Dict[str, Tuple[PreprocessResponse, Any]] = {}
48
48
 
49
49
  try:
50
50
  from code_loader.mixpanel_tracker import track_code_loader_loaded
@@ -191,22 +191,11 @@ class LeapLoader(LeapLoaderBase):
191
191
  def get_sample(self, state: DataStateEnum, sample_id: Union[int, str], instance_id: int = None) -> DatasetSample:
192
192
  self.exec_script()
193
193
 
194
- if isinstance(sample_id, str) and sample_id.startswith("synthetic_"):
195
- last_underscore = sample_id.rfind("_")
196
- if last_underscore != -1:
197
- prefix = sample_id[:last_underscore]
198
- local_idx_str = sample_id[last_underscore + 1:]
199
- sim_preprocess = self._synthetic_responses.get(prefix)
200
- if sim_preprocess is not None:
201
- try:
202
- local_idx = int(local_idx_str)
203
- original_sample_id = sim_preprocess.sample_ids[local_idx]
204
- except (ValueError, IndexError):
205
- pass
206
- else:
207
- return self._get_sample_from_preprocess(
208
- sim_preprocess, original_sample_id, synthetic_index=sample_id
209
- )
194
+ if isinstance(sample_id, str) and sample_id in self._synthetic_lookup:
195
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
196
+ return self._get_sample_from_preprocess(
197
+ sim_preprocess, original_local_id, synthetic_index=sample_id
198
+ )
210
199
 
211
200
  preprocess_result = self._preprocess_result()
212
201
  if state == DataStateEnum.unlabeled and sample_id not in preprocess_result[state].sample_ids:
@@ -408,8 +397,8 @@ class LeapLoader(LeapLoaderBase):
408
397
  result_payloads.append(test_result)
409
398
  return result_payloads
410
399
 
411
- def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_id_prefix=None):
412
- # type: (str, Optional[Dict[str, Any]], int, int, Optional[str]) -> Dict[str, Any]
400
+ def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_ids=None):
401
+ # type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]]) -> Dict[str, Any]
413
402
  self.exec_script()
414
403
  sim = next(
415
404
  (s for s in global_leap_binder.setup_container.simulations if s.name == sim_name),
@@ -433,13 +422,43 @@ class LeapLoader(LeapLoaderBase):
433
422
  for handler in global_leap_binder.setup_container.inputs:
434
423
  per_encoder[handler.name].append(handler.function(sample_id, sim_preprocess))
435
424
  encoded = {name: np.stack(arrays) for name, arrays in per_encoder.items()}
436
- if sample_id_prefix is not None:
437
- returned_sample_ids = ["{}_{}".format(sample_id_prefix, i) for i in range(len(original_sample_ids))]
438
- self._synthetic_responses[sample_id_prefix] = sim_preprocess
425
+ if sample_ids is not None:
426
+ if len(sample_ids) != len(original_sample_ids):
427
+ raise ValueError(
428
+ "sample_ids length ({}) does not match simulation output length ({})".format(
429
+ len(sample_ids), len(original_sample_ids)
430
+ )
431
+ )
432
+ for sid in sample_ids:
433
+ if not isinstance(sid, str):
434
+ raise TypeError(
435
+ "All sample_ids must be of type str. Got: {}".format(type(sid))
436
+ )
437
+ for synth_id, original_local_id in zip(sample_ids, original_sample_ids):
438
+ self._synthetic_lookup[synth_id] = (sim_preprocess, original_local_id)
439
+ self._extend_additional_preprocess(list(sample_ids))
440
+ returned_sample_ids = list(sample_ids)
439
441
  else:
440
442
  returned_sample_ids = original_sample_ids
441
443
  return {"encoded": encoded, "sample_ids": returned_sample_ids}
442
444
 
445
+ def _extend_additional_preprocess(self, new_sample_ids: List[str]) -> None:
446
+ if self._preprocess_result_cached is None:
447
+ self._preprocess_result()
448
+ additional = self._preprocess_result_cached.get(DataStateEnum.additional)
449
+ if additional is None:
450
+ placeholder = PreprocessResponse(
451
+ sample_ids=list(new_sample_ids),
452
+ data={},
453
+ state=DataStateType.additional,
454
+ sample_id_type=str,
455
+ )
456
+ placeholder.tl_generated = True
457
+ self._preprocess_result_cached[DataStateEnum.additional] = placeholder
458
+ else:
459
+ additional.sample_ids = list(additional.sample_ids) + list(new_sample_ids)
460
+ additional.length = len(additional.sample_ids)
461
+
443
462
  @staticmethod
444
463
  def _get_all_dataset_base_handlers() -> List[Union[DatasetBaseHandler, MetadataHandler]]:
445
464
  all_dataset_base_handlers: List[Union[DatasetBaseHandler, MetadataHandler]] = []
@@ -653,7 +672,13 @@ class LeapLoader(LeapLoaderBase):
653
672
  state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, npt.NDArray[np.float32]]:
654
673
  result_agg = {}
655
674
  preprocess_result = self._preprocess_result()
656
- preprocess_state = preprocess_result[state]
675
+ if (state == DataStateEnum.additional and isinstance(sample_id, str)
676
+ and sample_id in self._synthetic_lookup):
677
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
678
+ preprocess_state = sim_preprocess
679
+ sample_id = original_local_id
680
+ else:
681
+ preprocess_state = preprocess_result[state]
657
682
  for handler in handlers:
658
683
  handler_result = handler.function(sample_id, preprocess_state)
659
684
  handler_name = handler.name
@@ -667,7 +692,13 @@ class LeapLoader(LeapLoaderBase):
667
692
  if instance_id is None:
668
693
  return None
669
694
  preprocess_result = self._preprocess_result()
670
- preprocess_state = preprocess_result[state]
695
+ if (state == DataStateEnum.additional and isinstance(sample_id, str)
696
+ and sample_id in self._synthetic_lookup):
697
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
698
+ preprocess_state = sim_preprocess
699
+ sample_id = original_local_id
700
+ else:
701
+ preprocess_state = preprocess_result[state]
671
702
  result_agg = {}
672
703
  for handler in global_leap_binder.setup_container.instance_masks:
673
704
  handler_result = handler.function(sample_id, preprocess_state, instance_id)
@@ -727,7 +758,13 @@ class LeapLoader(LeapLoaderBase):
727
758
  result_agg = {}
728
759
  is_none = {}
729
760
  preprocess_result = self._preprocess_result()
730
- preprocess_state = preprocess_result[state]
761
+ if (state == DataStateEnum.additional and isinstance(sample_id, str)
762
+ and sample_id in self._synthetic_lookup):
763
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
764
+ preprocess_state = sim_preprocess
765
+ sample_id = original_local_id
766
+ else:
767
+ preprocess_state = preprocess_result[state]
731
768
  for handler in global_leap_binder.setup_container.metadata:
732
769
  if requested_metadata_names:
733
770
  if not is_metadata_name_starts_with_handler_name(handler):
@@ -154,8 +154,8 @@ class LeapLoaderBase:
154
154
  pass
155
155
 
156
156
  @abstractmethod
157
- def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_id_prefix=None):
158
- # type: (str, Optional[Dict[str, Any]], int, int, Optional[str]) -> Dict[str, Any]
157
+ def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_ids=None):
158
+ # type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]]) -> Dict[str, Any]
159
159
  pass
160
160
 
161
161
  def is_custom_latent_space(self) -> bool:
@@ -1,7 +1,7 @@
1
1
  [tool.poetry]
2
2
  name = "code-loader"
3
3
 
4
- version = "1.0.183.dev0"
4
+ version = "1.0.184.dev0"
5
5
  description = ""
6
6
  authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
7
7
  license = "MIT"