code-loader 1.0.183.dev0__py3-none-any.whl → 1.0.184.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_loader/leaploader.py +62 -25
- code_loader/leaploaderbase.py +2 -2
- {code_loader-1.0.183.dev0.dist-info → code_loader-1.0.184.dev0.dist-info}/METADATA +1 -1
- {code_loader-1.0.183.dev0.dist-info → code_loader-1.0.184.dev0.dist-info}/RECORD +6 -6
- {code_loader-1.0.183.dev0.dist-info → code_loader-1.0.184.dev0.dist-info}/LICENSE +0 -0
- {code_loader-1.0.183.dev0.dist-info → code_loader-1.0.184.dev0.dist-info}/WHEEL +0 -0
code_loader/leaploader.py
CHANGED
|
@@ -44,7 +44,7 @@ class LeapLoader(LeapLoaderBase):
|
|
|
44
44
|
super().__init__(code_path, code_entry_name)
|
|
45
45
|
|
|
46
46
|
self._preprocess_result_cached = None
|
|
47
|
-
self.
|
|
47
|
+
self._synthetic_lookup: Dict[str, Tuple[PreprocessResponse, Any]] = {}
|
|
48
48
|
|
|
49
49
|
try:
|
|
50
50
|
from code_loader.mixpanel_tracker import track_code_loader_loaded
|
|
@@ -191,22 +191,11 @@ class LeapLoader(LeapLoaderBase):
|
|
|
191
191
|
def get_sample(self, state: DataStateEnum, sample_id: Union[int, str], instance_id: int = None) -> DatasetSample:
|
|
192
192
|
self.exec_script()
|
|
193
193
|
|
|
194
|
-
if isinstance(sample_id, str) and sample_id.
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
sim_preprocess = self._synthetic_responses.get(prefix)
|
|
200
|
-
if sim_preprocess is not None:
|
|
201
|
-
try:
|
|
202
|
-
local_idx = int(local_idx_str)
|
|
203
|
-
original_sample_id = sim_preprocess.sample_ids[local_idx]
|
|
204
|
-
except (ValueError, IndexError):
|
|
205
|
-
pass
|
|
206
|
-
else:
|
|
207
|
-
return self._get_sample_from_preprocess(
|
|
208
|
-
sim_preprocess, original_sample_id, synthetic_index=sample_id
|
|
209
|
-
)
|
|
194
|
+
if isinstance(sample_id, str) and sample_id in self._synthetic_lookup:
|
|
195
|
+
sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
|
|
196
|
+
return self._get_sample_from_preprocess(
|
|
197
|
+
sim_preprocess, original_local_id, synthetic_index=sample_id
|
|
198
|
+
)
|
|
210
199
|
|
|
211
200
|
preprocess_result = self._preprocess_result()
|
|
212
201
|
if state == DataStateEnum.unlabeled and sample_id not in preprocess_result[state].sample_ids:
|
|
@@ -408,8 +397,8 @@ class LeapLoader(LeapLoaderBase):
|
|
|
408
397
|
result_payloads.append(test_result)
|
|
409
398
|
return result_payloads
|
|
410
399
|
|
|
411
|
-
def run_simulation(self, sim_name, params=None, n_samples=1, seed=0,
|
|
412
|
-
# type: (str, Optional[Dict[str, Any]], int, int, Optional[str]) -> Dict[str, Any]
|
|
400
|
+
def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_ids=None):
|
|
401
|
+
# type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]]) -> Dict[str, Any]
|
|
413
402
|
self.exec_script()
|
|
414
403
|
sim = next(
|
|
415
404
|
(s for s in global_leap_binder.setup_container.simulations if s.name == sim_name),
|
|
@@ -433,13 +422,43 @@ class LeapLoader(LeapLoaderBase):
|
|
|
433
422
|
for handler in global_leap_binder.setup_container.inputs:
|
|
434
423
|
per_encoder[handler.name].append(handler.function(sample_id, sim_preprocess))
|
|
435
424
|
encoded = {name: np.stack(arrays) for name, arrays in per_encoder.items()}
|
|
436
|
-
if
|
|
437
|
-
|
|
438
|
-
|
|
425
|
+
if sample_ids is not None:
|
|
426
|
+
if len(sample_ids) != len(original_sample_ids):
|
|
427
|
+
raise ValueError(
|
|
428
|
+
"sample_ids length ({}) does not match simulation output length ({})".format(
|
|
429
|
+
len(sample_ids), len(original_sample_ids)
|
|
430
|
+
)
|
|
431
|
+
)
|
|
432
|
+
for sid in sample_ids:
|
|
433
|
+
if not isinstance(sid, str):
|
|
434
|
+
raise TypeError(
|
|
435
|
+
"All sample_ids must be of type str. Got: {}".format(type(sid))
|
|
436
|
+
)
|
|
437
|
+
for synth_id, original_local_id in zip(sample_ids, original_sample_ids):
|
|
438
|
+
self._synthetic_lookup[synth_id] = (sim_preprocess, original_local_id)
|
|
439
|
+
self._extend_additional_preprocess(list(sample_ids))
|
|
440
|
+
returned_sample_ids = list(sample_ids)
|
|
439
441
|
else:
|
|
440
442
|
returned_sample_ids = original_sample_ids
|
|
441
443
|
return {"encoded": encoded, "sample_ids": returned_sample_ids}
|
|
442
444
|
|
|
445
|
+
def _extend_additional_preprocess(self, new_sample_ids: List[str]) -> None:
|
|
446
|
+
if self._preprocess_result_cached is None:
|
|
447
|
+
self._preprocess_result()
|
|
448
|
+
additional = self._preprocess_result_cached.get(DataStateEnum.additional)
|
|
449
|
+
if additional is None:
|
|
450
|
+
placeholder = PreprocessResponse(
|
|
451
|
+
sample_ids=list(new_sample_ids),
|
|
452
|
+
data={},
|
|
453
|
+
state=DataStateType.additional,
|
|
454
|
+
sample_id_type=str,
|
|
455
|
+
)
|
|
456
|
+
placeholder.tl_generated = True
|
|
457
|
+
self._preprocess_result_cached[DataStateEnum.additional] = placeholder
|
|
458
|
+
else:
|
|
459
|
+
additional.sample_ids = list(additional.sample_ids) + list(new_sample_ids)
|
|
460
|
+
additional.length = len(additional.sample_ids)
|
|
461
|
+
|
|
443
462
|
@staticmethod
|
|
444
463
|
def _get_all_dataset_base_handlers() -> List[Union[DatasetBaseHandler, MetadataHandler]]:
|
|
445
464
|
all_dataset_base_handlers: List[Union[DatasetBaseHandler, MetadataHandler]] = []
|
|
@@ -653,7 +672,13 @@ class LeapLoader(LeapLoaderBase):
|
|
|
653
672
|
state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, npt.NDArray[np.float32]]:
|
|
654
673
|
result_agg = {}
|
|
655
674
|
preprocess_result = self._preprocess_result()
|
|
656
|
-
|
|
675
|
+
if (state == DataStateEnum.additional and isinstance(sample_id, str)
|
|
676
|
+
and sample_id in self._synthetic_lookup):
|
|
677
|
+
sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
|
|
678
|
+
preprocess_state = sim_preprocess
|
|
679
|
+
sample_id = original_local_id
|
|
680
|
+
else:
|
|
681
|
+
preprocess_state = preprocess_result[state]
|
|
657
682
|
for handler in handlers:
|
|
658
683
|
handler_result = handler.function(sample_id, preprocess_state)
|
|
659
684
|
handler_name = handler.name
|
|
@@ -667,7 +692,13 @@ class LeapLoader(LeapLoaderBase):
|
|
|
667
692
|
if instance_id is None:
|
|
668
693
|
return None
|
|
669
694
|
preprocess_result = self._preprocess_result()
|
|
670
|
-
|
|
695
|
+
if (state == DataStateEnum.additional and isinstance(sample_id, str)
|
|
696
|
+
and sample_id in self._synthetic_lookup):
|
|
697
|
+
sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
|
|
698
|
+
preprocess_state = sim_preprocess
|
|
699
|
+
sample_id = original_local_id
|
|
700
|
+
else:
|
|
701
|
+
preprocess_state = preprocess_result[state]
|
|
671
702
|
result_agg = {}
|
|
672
703
|
for handler in global_leap_binder.setup_container.instance_masks:
|
|
673
704
|
handler_result = handler.function(sample_id, preprocess_state, instance_id)
|
|
@@ -727,7 +758,13 @@ class LeapLoader(LeapLoaderBase):
|
|
|
727
758
|
result_agg = {}
|
|
728
759
|
is_none = {}
|
|
729
760
|
preprocess_result = self._preprocess_result()
|
|
730
|
-
|
|
761
|
+
if (state == DataStateEnum.additional and isinstance(sample_id, str)
|
|
762
|
+
and sample_id in self._synthetic_lookup):
|
|
763
|
+
sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
|
|
764
|
+
preprocess_state = sim_preprocess
|
|
765
|
+
sample_id = original_local_id
|
|
766
|
+
else:
|
|
767
|
+
preprocess_state = preprocess_result[state]
|
|
731
768
|
for handler in global_leap_binder.setup_container.metadata:
|
|
732
769
|
if requested_metadata_names:
|
|
733
770
|
if not is_metadata_name_starts_with_handler_name(handler):
|
code_loader/leaploaderbase.py
CHANGED
|
@@ -154,8 +154,8 @@ class LeapLoaderBase:
|
|
|
154
154
|
pass
|
|
155
155
|
|
|
156
156
|
@abstractmethod
|
|
157
|
-
def run_simulation(self, sim_name, params=None, n_samples=1, seed=0,
|
|
158
|
-
# type: (str, Optional[Dict[str, Any]], int, int, Optional[str]) -> Dict[str, Any]
|
|
157
|
+
def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_ids=None):
|
|
158
|
+
# type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]]) -> Dict[str, Any]
|
|
159
159
|
pass
|
|
160
160
|
|
|
161
161
|
def is_custom_latent_space(self) -> bool:
|
|
@@ -23,8 +23,8 @@ code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaS
|
|
|
23
23
|
code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
|
|
24
24
|
code_loader/inner_leap_binder/leapbinder.py,sha256=XLYYcV50qjMvoC1S6WW0tLBch_0g5gl1UyHiVSWYbvg,40491
|
|
25
25
|
code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=YlK51b5Ryo396f7tOA7Ole3vYCLs3f5ZLm2qDQ9K1NE,105781
|
|
26
|
-
code_loader/leaploader.py,sha256=
|
|
27
|
-
code_loader/leaploaderbase.py,sha256=
|
|
26
|
+
code_loader/leaploader.py,sha256=mFeCbDeyJmk3F1mR-CDJ3BBgJ_GTdTjnR7xR7j66Ru0,41781
|
|
27
|
+
code_loader/leaploaderbase.py,sha256=l36qDA00GhZEG5NLKpEtAXgWJA-UQQIhNFGxywK7mUA,6530
|
|
28
28
|
code_loader/mixpanel_tracker.py,sha256=rNwRmFifNbdUoqLQvvhhgpKczWpWiEmd8MfyJe27sxw,9131
|
|
29
29
|
code_loader/plot_functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
30
|
code_loader/plot_functions/plot_functions.py,sha256=6Q7VWGxetL2W0EK2QeCdObVATvBuHs3YBA09H4uoIk0,14996
|
|
@@ -32,7 +32,7 @@ code_loader/plot_functions/visualize.py,sha256=gsBAYYkwMh7jIpJeDMPS8G4CW-pxwx6Lz
|
|
|
32
32
|
code_loader/utils.py,sha256=YecipkdTA-VcE9F0RQcY9cFnY8P3AksPnHM2Db7xUSk,3972
|
|
33
33
|
code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
34
|
code_loader/visualizers/default_visualizers.py,sha256=onRnLE_TXfgLN4o52hQIOOhUcFexGlqJ3xSpQDVLuZM,2604
|
|
35
|
-
code_loader-1.0.
|
|
36
|
-
code_loader-1.0.
|
|
37
|
-
code_loader-1.0.
|
|
38
|
-
code_loader-1.0.
|
|
35
|
+
code_loader-1.0.184.dev0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
|
36
|
+
code_loader-1.0.184.dev0.dist-info/METADATA,sha256=dFaZKWGS-beIlKCmQI3RON7Iq_EalcHluBC-j9FVL3k,1095
|
|
37
|
+
code_loader-1.0.184.dev0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
38
|
+
code_loader-1.0.184.dev0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|