code-loader 1.0.182.dev0__tar.gz → 1.0.184.dev0__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/PKG-INFO +1 -1
  2. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/leaploader.py +114 -7
  3. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/leaploaderbase.py +2 -2
  4. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/pyproject.toml +1 -1
  5. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/LICENSE +0 -0
  6. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/README.md +0 -0
  7. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/__init__.py +0 -0
  8. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/__init__.py +0 -0
  9. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/datasetclasses.py +0 -0
  10. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/enums.py +0 -0
  11. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/exceptions.py +0 -0
  12. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/mapping.py +0 -0
  13. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/responsedataclasses.py +0 -0
  14. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/sim_config.py +0 -0
  15. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/contract/visualizer_classes.py +0 -0
  16. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/default_losses.py +0 -0
  17. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/default_metrics.py +0 -0
  18. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/__init__.py +0 -0
  19. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/api.py +0 -0
  20. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/cli_config_utils.py +0 -0
  21. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/client.py +0 -0
  22. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/epoch.py +0 -0
  23. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/experiment.py +0 -0
  24. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/experiment_context.py +0 -0
  25. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/types.py +0 -0
  26. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/utils.py +0 -0
  27. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/experiment_api/workingspace_config_utils.py +0 -0
  28. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/inner_leap_binder/__init__.py +0 -0
  29. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/inner_leap_binder/leapbinder.py +0 -0
  30. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/inner_leap_binder/leapbinder_decorators.py +0 -0
  31. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/mixpanel_tracker.py +0 -0
  32. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/plot_functions/__init__.py +0 -0
  33. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/plot_functions/plot_functions.py +0 -0
  34. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/plot_functions/visualize.py +0 -0
  35. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/utils.py +0 -0
  36. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/visualizers/__init__.py +0 -0
  37. {code_loader-1.0.182.dev0 → code_loader-1.0.184.dev0}/code_loader/visualizers/default_visualizers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.182.dev0
3
+ Version: 1.0.184.dev0
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -44,6 +44,7 @@ class LeapLoader(LeapLoaderBase):
44
44
  super().__init__(code_path, code_entry_name)
45
45
 
46
46
  self._preprocess_result_cached = None
47
+ self._synthetic_lookup: Dict[str, Tuple[PreprocessResponse, Any]] = {}
47
48
 
48
49
  try:
49
50
  from code_loader.mixpanel_tracker import track_code_loader_loaded
@@ -189,6 +190,13 @@ class LeapLoader(LeapLoaderBase):
189
190
 
190
191
  def get_sample(self, state: DataStateEnum, sample_id: Union[int, str], instance_id: int = None) -> DatasetSample:
191
192
  self.exec_script()
193
+
194
+ if isinstance(sample_id, str) and sample_id in self._synthetic_lookup:
195
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
196
+ return self._get_sample_from_preprocess(
197
+ sim_preprocess, original_local_id, synthetic_index=sample_id
198
+ )
199
+
192
200
  preprocess_result = self._preprocess_result()
193
201
  if state == DataStateEnum.unlabeled and sample_id not in preprocess_result[state].sample_ids:
194
202
  self._preprocess_result(update_unlabeled_preprocess=True)
@@ -211,6 +219,50 @@ class LeapLoader(LeapLoaderBase):
211
219
  instance_masks=instance_mask)
212
220
  return sample
213
221
 
222
+ def _get_sample_from_preprocess(
223
+ self,
224
+ preprocess: "PreprocessResponse",
225
+ original_sample_id: Union[int, str],
226
+ synthetic_index: str,
227
+ ) -> DatasetSample:
228
+ inputs = {}
229
+ for handler in global_leap_binder.setup_container.inputs:
230
+ inputs[handler.name] = handler.function(original_sample_id, preprocess)
231
+
232
+ gt = {}
233
+ for handler in global_leap_binder.setup_container.ground_truths:
234
+ gt[handler.name] = handler.function(original_sample_id, preprocess)
235
+
236
+ metadata = {}
237
+ metadata_is_none = {}
238
+ for handler in global_leap_binder.setup_container.metadata:
239
+ handler_result = handler.function(original_sample_id, preprocess)
240
+ if isinstance(handler_result, dict):
241
+ for k, v in handler_result.items():
242
+ key = "{}_{}".format(handler.name, k)
243
+ metadata[key], metadata_is_none[key] = self._convert_metadata_to_correct_type(key, v)
244
+ else:
245
+ metadata[handler.name], metadata_is_none[handler.name] = self._convert_metadata_to_correct_type(
246
+ handler.name, handler_result
247
+ )
248
+
249
+ custom_latent_space = None
250
+ if global_leap_binder.setup_container.custom_latent_space is not None:
251
+ custom_latent_space = global_leap_binder.setup_container.custom_latent_space.function(
252
+ original_sample_id, preprocess
253
+ )
254
+
255
+ return DatasetSample(
256
+ inputs=inputs,
257
+ gt=gt if gt else None,
258
+ metadata=metadata,
259
+ metadata_is_none=metadata_is_none,
260
+ index=synthetic_index,
261
+ state=DataStateEnum.additional,
262
+ custom_latent_space=custom_latent_space,
263
+ instance_masks=None,
264
+ )
265
+
214
266
  def check_dataset(self) -> DatasetIntegParseResult:
215
267
  test_payloads: List[DatasetTestResultPayload] = []
216
268
  setup_response = None
@@ -345,8 +397,8 @@ class LeapLoader(LeapLoaderBase):
345
397
  result_payloads.append(test_result)
346
398
  return result_payloads
347
399
 
348
- def run_simulation(self, sim_name, params=None, n_samples=1, seed=0):
349
- # type: (str, Optional[Dict[str, Any]], int, int) -> Dict[str, npt.NDArray[np.float32]]
400
+ def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_ids=None):
401
+ # type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]]) -> Dict[str, Any]
350
402
  self.exec_script()
351
403
  sim = next(
352
404
  (s for s in global_leap_binder.setup_container.simulations if s.name == sim_name),
@@ -364,11 +416,48 @@ class LeapLoader(LeapLoaderBase):
364
416
  _simulation_context["active"] = False
365
417
  sim_preprocess.state = DataStateType.additional
366
418
  sim_preprocess.tl_generated = True
419
+ original_sample_ids = list(sim_preprocess.sample_ids)
367
420
  per_encoder = {handler.name: [] for handler in global_leap_binder.setup_container.inputs}
368
- for sample_id in sim_preprocess.sample_ids:
421
+ for sample_id in original_sample_ids:
369
422
  for handler in global_leap_binder.setup_container.inputs:
370
423
  per_encoder[handler.name].append(handler.function(sample_id, sim_preprocess))
371
- return {name: np.stack(arrays) for name, arrays in per_encoder.items()}
424
+ encoded = {name: np.stack(arrays) for name, arrays in per_encoder.items()}
425
+ if sample_ids is not None:
426
+ if len(sample_ids) != len(original_sample_ids):
427
+ raise ValueError(
428
+ "sample_ids length ({}) does not match simulation output length ({})".format(
429
+ len(sample_ids), len(original_sample_ids)
430
+ )
431
+ )
432
+ for sid in sample_ids:
433
+ if not isinstance(sid, str):
434
+ raise TypeError(
435
+ "All sample_ids must be of type str. Got: {}".format(type(sid))
436
+ )
437
+ for synth_id, original_local_id in zip(sample_ids, original_sample_ids):
438
+ self._synthetic_lookup[synth_id] = (sim_preprocess, original_local_id)
439
+ self._extend_additional_preprocess(list(sample_ids))
440
+ returned_sample_ids = list(sample_ids)
441
+ else:
442
+ returned_sample_ids = original_sample_ids
443
+ return {"encoded": encoded, "sample_ids": returned_sample_ids}
444
+
445
+ def _extend_additional_preprocess(self, new_sample_ids: List[str]) -> None:
446
+ if self._preprocess_result_cached is None:
447
+ self._preprocess_result()
448
+ additional = self._preprocess_result_cached.get(DataStateEnum.additional)
449
+ if additional is None:
450
+ placeholder = PreprocessResponse(
451
+ sample_ids=list(new_sample_ids),
452
+ data={},
453
+ state=DataStateType.additional,
454
+ sample_id_type=str,
455
+ )
456
+ placeholder.tl_generated = True
457
+ self._preprocess_result_cached[DataStateEnum.additional] = placeholder
458
+ else:
459
+ additional.sample_ids = list(additional.sample_ids) + list(new_sample_ids)
460
+ additional.length = len(additional.sample_ids)
372
461
 
373
462
  @staticmethod
374
463
  def _get_all_dataset_base_handlers() -> List[Union[DatasetBaseHandler, MetadataHandler]]:
@@ -583,7 +672,13 @@ class LeapLoader(LeapLoaderBase):
583
672
  state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, npt.NDArray[np.float32]]:
584
673
  result_agg = {}
585
674
  preprocess_result = self._preprocess_result()
586
- preprocess_state = preprocess_result[state]
675
+ if (state == DataStateEnum.additional and isinstance(sample_id, str)
676
+ and sample_id in self._synthetic_lookup):
677
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
678
+ preprocess_state = sim_preprocess
679
+ sample_id = original_local_id
680
+ else:
681
+ preprocess_state = preprocess_result[state]
587
682
  for handler in handlers:
588
683
  handler_result = handler.function(sample_id, preprocess_state)
589
684
  handler_name = handler.name
@@ -597,7 +692,13 @@ class LeapLoader(LeapLoaderBase):
597
692
  if instance_id is None:
598
693
  return None
599
694
  preprocess_result = self._preprocess_result()
600
- preprocess_state = preprocess_result[state]
695
+ if (state == DataStateEnum.additional and isinstance(sample_id, str)
696
+ and sample_id in self._synthetic_lookup):
697
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
698
+ preprocess_state = sim_preprocess
699
+ sample_id = original_local_id
700
+ else:
701
+ preprocess_state = preprocess_result[state]
601
702
  result_agg = {}
602
703
  for handler in global_leap_binder.setup_container.instance_masks:
603
704
  handler_result = handler.function(sample_id, preprocess_state, instance_id)
@@ -657,7 +758,13 @@ class LeapLoader(LeapLoaderBase):
657
758
  result_agg = {}
658
759
  is_none = {}
659
760
  preprocess_result = self._preprocess_result()
660
- preprocess_state = preprocess_result[state]
761
+ if (state == DataStateEnum.additional and isinstance(sample_id, str)
762
+ and sample_id in self._synthetic_lookup):
763
+ sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
764
+ preprocess_state = sim_preprocess
765
+ sample_id = original_local_id
766
+ else:
767
+ preprocess_state = preprocess_result[state]
661
768
  for handler in global_leap_binder.setup_container.metadata:
662
769
  if requested_metadata_names:
663
770
  if not is_metadata_name_starts_with_handler_name(handler):
@@ -154,8 +154,8 @@ class LeapLoaderBase:
154
154
  pass
155
155
 
156
156
  @abstractmethod
157
- def run_simulation(self, sim_name, params=None, n_samples=1, seed=0):
158
- # type: (str, Optional[Dict[str, Any]], int, int) -> Dict[str, Any]
157
+ def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_ids=None):
158
+ # type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]]) -> Dict[str, Any]
159
159
  pass
160
160
 
161
161
  def is_custom_latent_space(self) -> bool:
@@ -1,7 +1,7 @@
1
1
  [tool.poetry]
2
2
  name = "code-loader"
3
3
 
4
- version = "1.0.182.dev0"
4
+ version = "1.0.184.dev0"
5
5
  description = ""
6
6
  authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
7
7
  license = "MIT"