code-loader 1.0.182.dev0__py3-none-any.whl → 1.0.184.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_loader/leaploader.py +114 -7
- code_loader/leaploaderbase.py +2 -2
- {code_loader-1.0.182.dev0.dist-info → code_loader-1.0.184.dev0.dist-info}/METADATA +1 -1
- {code_loader-1.0.182.dev0.dist-info → code_loader-1.0.184.dev0.dist-info}/RECORD +6 -6
- {code_loader-1.0.182.dev0.dist-info → code_loader-1.0.184.dev0.dist-info}/LICENSE +0 -0
- {code_loader-1.0.182.dev0.dist-info → code_loader-1.0.184.dev0.dist-info}/WHEEL +0 -0
code_loader/leaploader.py
CHANGED
|
@@ -44,6 +44,7 @@ class LeapLoader(LeapLoaderBase):
|
|
|
44
44
|
super().__init__(code_path, code_entry_name)
|
|
45
45
|
|
|
46
46
|
self._preprocess_result_cached = None
|
|
47
|
+
self._synthetic_lookup: Dict[str, Tuple[PreprocessResponse, Any]] = {}
|
|
47
48
|
|
|
48
49
|
try:
|
|
49
50
|
from code_loader.mixpanel_tracker import track_code_loader_loaded
|
|
@@ -189,6 +190,13 @@ class LeapLoader(LeapLoaderBase):
|
|
|
189
190
|
|
|
190
191
|
def get_sample(self, state: DataStateEnum, sample_id: Union[int, str], instance_id: int = None) -> DatasetSample:
|
|
191
192
|
self.exec_script()
|
|
193
|
+
|
|
194
|
+
if isinstance(sample_id, str) and sample_id in self._synthetic_lookup:
|
|
195
|
+
sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
|
|
196
|
+
return self._get_sample_from_preprocess(
|
|
197
|
+
sim_preprocess, original_local_id, synthetic_index=sample_id
|
|
198
|
+
)
|
|
199
|
+
|
|
192
200
|
preprocess_result = self._preprocess_result()
|
|
193
201
|
if state == DataStateEnum.unlabeled and sample_id not in preprocess_result[state].sample_ids:
|
|
194
202
|
self._preprocess_result(update_unlabeled_preprocess=True)
|
|
@@ -211,6 +219,50 @@ class LeapLoader(LeapLoaderBase):
|
|
|
211
219
|
instance_masks=instance_mask)
|
|
212
220
|
return sample
|
|
213
221
|
|
|
222
|
+
def _get_sample_from_preprocess(
    self,
    preprocess: "PreprocessResponse",
    original_sample_id: Union[int, str],
    synthetic_index: str,
) -> DatasetSample:
    """Assemble a ``DatasetSample`` for a simulation-generated (synthetic) sample.

    ``get_sample`` routes here for string ids registered in
    ``self._synthetic_lookup``. Every registered input, ground-truth,
    metadata and custom-latent-space handler is invoked with the sample's
    *original* local id against the simulation's preprocess response, while
    the returned sample is labelled with the synthetic id and the
    ``additional`` state.

    Args:
        preprocess: The preprocess response stored for this synthetic id
            (the simulation's own response).
        original_sample_id: Id of the sample inside ``preprocess``.
        synthetic_index: Caller-facing id, stored as the sample's ``index``.

    Returns:
        The assembled ``DatasetSample``. ``gt`` is ``None`` when no
        ground-truth handler is registered; ``instance_masks`` is always
        ``None`` on this path.
    """
    inputs = {
        handler.name: handler.function(original_sample_id, preprocess)
        for handler in global_leap_binder.setup_container.inputs
    }
    gt = {
        handler.name: handler.function(original_sample_id, preprocess)
        for handler in global_leap_binder.setup_container.ground_truths
    }

    metadata = {}
    metadata_is_none = {}
    for handler in global_leap_binder.setup_container.metadata:
        raw_value = handler.function(original_sample_id, preprocess)
        # A dict-returning metadata handler contributes one entry per key,
        # named "<handler name>_<key>"; anything else is a single entry.
        if isinstance(raw_value, dict):
            entries = (
                ("{}_{}".format(handler.name, sub_key), sub_value)
                for sub_key, sub_value in raw_value.items()
            )
        else:
            entries = ((handler.name, raw_value),)
        for entry_name, entry_value in entries:
            converted, is_none = self._convert_metadata_to_correct_type(entry_name, entry_value)
            metadata[entry_name] = converted
            metadata_is_none[entry_name] = is_none

    latent_handler = global_leap_binder.setup_container.custom_latent_space
    custom_latent_space = (
        latent_handler.function(original_sample_id, preprocess)
        if latent_handler is not None
        else None
    )

    return DatasetSample(
        inputs=inputs,
        gt=gt or None,
        metadata=metadata,
        metadata_is_none=metadata_is_none,
        index=synthetic_index,
        state=DataStateEnum.additional,
        custom_latent_space=custom_latent_space,
        instance_masks=None,
    )
|
|
265
|
+
|
|
214
266
|
def check_dataset(self) -> DatasetIntegParseResult:
|
|
215
267
|
test_payloads: List[DatasetTestResultPayload] = []
|
|
216
268
|
setup_response = None
|
|
@@ -345,8 +397,8 @@ class LeapLoader(LeapLoaderBase):
|
|
|
345
397
|
result_payloads.append(test_result)
|
|
346
398
|
return result_payloads
|
|
347
399
|
|
|
348
|
-
def run_simulation(self, sim_name, params=None, n_samples=1, seed=0):
|
|
349
|
-
# type: (str, Optional[Dict[str, Any]], int, int) -> Dict[str,
|
|
400
|
+
def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_ids=None):
|
|
401
|
+
# type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]]) -> Dict[str, Any]
|
|
350
402
|
self.exec_script()
|
|
351
403
|
sim = next(
|
|
352
404
|
(s for s in global_leap_binder.setup_container.simulations if s.name == sim_name),
|
|
@@ -364,11 +416,48 @@ class LeapLoader(LeapLoaderBase):
|
|
|
364
416
|
_simulation_context["active"] = False
|
|
365
417
|
sim_preprocess.state = DataStateType.additional
|
|
366
418
|
sim_preprocess.tl_generated = True
|
|
419
|
+
original_sample_ids = list(sim_preprocess.sample_ids)
|
|
367
420
|
per_encoder = {handler.name: [] for handler in global_leap_binder.setup_container.inputs}
|
|
368
|
-
for sample_id in
|
|
421
|
+
for sample_id in original_sample_ids:
|
|
369
422
|
for handler in global_leap_binder.setup_container.inputs:
|
|
370
423
|
per_encoder[handler.name].append(handler.function(sample_id, sim_preprocess))
|
|
371
|
-
|
|
424
|
+
encoded = {name: np.stack(arrays) for name, arrays in per_encoder.items()}
|
|
425
|
+
if sample_ids is not None:
|
|
426
|
+
if len(sample_ids) != len(original_sample_ids):
|
|
427
|
+
raise ValueError(
|
|
428
|
+
"sample_ids length ({}) does not match simulation output length ({})".format(
|
|
429
|
+
len(sample_ids), len(original_sample_ids)
|
|
430
|
+
)
|
|
431
|
+
)
|
|
432
|
+
for sid in sample_ids:
|
|
433
|
+
if not isinstance(sid, str):
|
|
434
|
+
raise TypeError(
|
|
435
|
+
"All sample_ids must be of type str. Got: {}".format(type(sid))
|
|
436
|
+
)
|
|
437
|
+
for synth_id, original_local_id in zip(sample_ids, original_sample_ids):
|
|
438
|
+
self._synthetic_lookup[synth_id] = (sim_preprocess, original_local_id)
|
|
439
|
+
self._extend_additional_preprocess(list(sample_ids))
|
|
440
|
+
returned_sample_ids = list(sample_ids)
|
|
441
|
+
else:
|
|
442
|
+
returned_sample_ids = original_sample_ids
|
|
443
|
+
return {"encoded": encoded, "sample_ids": returned_sample_ids}
|
|
444
|
+
|
|
445
|
+
def _extend_additional_preprocess(self, new_sample_ids: List[str]) -> None:
    """Register synthetic sample ids on the cached ``additional`` preprocess response.

    If no ``additional`` response is cached yet, a placeholder response
    holding only the new ids is created (``data={}``, ``sample_id_type=str``,
    ``tl_generated=True``). Otherwise the ids are appended to the existing
    response's ``sample_ids`` and its ``length`` is updated.

    Ids are de-duplicated while preserving order. Ids already present in the
    cached ``additional`` response are not appended a second time, so
    re-running a simulation with the same ids does not grow the list.

    Args:
        new_sample_ids: Synthetic ids to expose under the ``additional`` state.
    """
    if self._preprocess_result_cached is None:
        # Presumably populates ``_preprocess_result_cached`` as a side effect.
        self._preprocess_result()

    # ``run_simulation`` does not check uniqueness of caller-supplied ids, and
    # re-running it with the same ids only re-points ``_synthetic_lookup``
    # entries (a dict). The additional state must therefore list each id once.
    unique_new_ids = list(dict.fromkeys(new_sample_ids))

    additional = self._preprocess_result_cached.get(DataStateEnum.additional)
    if additional is None:
        placeholder = PreprocessResponse(
            sample_ids=unique_new_ids,
            data={},
            state=DataStateType.additional,
            sample_id_type=str,
        )
        placeholder.tl_generated = True
        self._preprocess_result_cached[DataStateEnum.additional] = placeholder
        return

    existing_ids = list(additional.sample_ids)
    already_present = set(existing_ids)
    additional.sample_ids = existing_ids + [
        sid for sid in unique_new_ids if sid not in already_present
    ]
    additional.length = len(additional.sample_ids)
|
|
372
461
|
|
|
373
462
|
@staticmethod
|
|
374
463
|
def _get_all_dataset_base_handlers() -> List[Union[DatasetBaseHandler, MetadataHandler]]:
|
|
@@ -583,7 +672,13 @@ class LeapLoader(LeapLoaderBase):
|
|
|
583
672
|
state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, npt.NDArray[np.float32]]:
|
|
584
673
|
result_agg = {}
|
|
585
674
|
preprocess_result = self._preprocess_result()
|
|
586
|
-
|
|
675
|
+
if (state == DataStateEnum.additional and isinstance(sample_id, str)
|
|
676
|
+
and sample_id in self._synthetic_lookup):
|
|
677
|
+
sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
|
|
678
|
+
preprocess_state = sim_preprocess
|
|
679
|
+
sample_id = original_local_id
|
|
680
|
+
else:
|
|
681
|
+
preprocess_state = preprocess_result[state]
|
|
587
682
|
for handler in handlers:
|
|
588
683
|
handler_result = handler.function(sample_id, preprocess_state)
|
|
589
684
|
handler_name = handler.name
|
|
@@ -597,7 +692,13 @@ class LeapLoader(LeapLoaderBase):
|
|
|
597
692
|
if instance_id is None:
|
|
598
693
|
return None
|
|
599
694
|
preprocess_result = self._preprocess_result()
|
|
600
|
-
|
|
695
|
+
if (state == DataStateEnum.additional and isinstance(sample_id, str)
|
|
696
|
+
and sample_id in self._synthetic_lookup):
|
|
697
|
+
sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
|
|
698
|
+
preprocess_state = sim_preprocess
|
|
699
|
+
sample_id = original_local_id
|
|
700
|
+
else:
|
|
701
|
+
preprocess_state = preprocess_result[state]
|
|
601
702
|
result_agg = {}
|
|
602
703
|
for handler in global_leap_binder.setup_container.instance_masks:
|
|
603
704
|
handler_result = handler.function(sample_id, preprocess_state, instance_id)
|
|
@@ -657,7 +758,13 @@ class LeapLoader(LeapLoaderBase):
|
|
|
657
758
|
result_agg = {}
|
|
658
759
|
is_none = {}
|
|
659
760
|
preprocess_result = self._preprocess_result()
|
|
660
|
-
|
|
761
|
+
if (state == DataStateEnum.additional and isinstance(sample_id, str)
|
|
762
|
+
and sample_id in self._synthetic_lookup):
|
|
763
|
+
sim_preprocess, original_local_id = self._synthetic_lookup[sample_id]
|
|
764
|
+
preprocess_state = sim_preprocess
|
|
765
|
+
sample_id = original_local_id
|
|
766
|
+
else:
|
|
767
|
+
preprocess_state = preprocess_result[state]
|
|
661
768
|
for handler in global_leap_binder.setup_container.metadata:
|
|
662
769
|
if requested_metadata_names:
|
|
663
770
|
if not is_metadata_name_starts_with_handler_name(handler):
|
code_loader/leaploaderbase.py
CHANGED
|
@@ -154,8 +154,8 @@ class LeapLoaderBase:
|
|
|
154
154
|
pass
|
|
155
155
|
|
|
156
156
|
@abstractmethod
def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_ids=None):
    # type: (str, Optional[Dict[str, Any]], int, int, Optional[List[str]]) -> Dict[str, Any]
    """Run the registered simulation named ``sim_name`` and return its encoded output.

    Abstract hook; concrete loaders provide the behavior. ``params``,
    ``n_samples`` and ``seed`` are simulation arguments whose exact use is
    implementation-defined.

    Args:
        sim_name: Name of the registered simulation to run.
        params: Optional simulation parameters.
        n_samples: Number of samples requested from the simulation.
        seed: Seed for the simulation run.
        sample_ids: Optional caller-chosen string ids, one per generated
            sample. In ``LeapLoader`` a length mismatch raises ``ValueError``,
            a non-``str`` id raises ``TypeError``, and valid ids are
            registered so they can later be fetched via ``get_sample``.

    Returns:
        A dict. ``LeapLoader`` returns the keys ``"encoded"`` and
        ``"sample_ids"``.
    """
|
|
160
160
|
|
|
161
161
|
def is_custom_latent_space(self) -> bool:
|
|
@@ -23,8 +23,8 @@ code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaS
|
|
|
23
23
|
code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
|
|
24
24
|
code_loader/inner_leap_binder/leapbinder.py,sha256=XLYYcV50qjMvoC1S6WW0tLBch_0g5gl1UyHiVSWYbvg,40491
|
|
25
25
|
code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=YlK51b5Ryo396f7tOA7Ole3vYCLs3f5ZLm2qDQ9K1NE,105781
|
|
26
|
-
code_loader/leaploader.py,sha256=
|
|
27
|
-
code_loader/leaploaderbase.py,sha256=
|
|
26
|
+
code_loader/leaploader.py,sha256=mFeCbDeyJmk3F1mR-CDJ3BBgJ_GTdTjnR7xR7j66Ru0,41781
|
|
27
|
+
code_loader/leaploaderbase.py,sha256=l36qDA00GhZEG5NLKpEtAXgWJA-UQQIhNFGxywK7mUA,6530
|
|
28
28
|
code_loader/mixpanel_tracker.py,sha256=rNwRmFifNbdUoqLQvvhhgpKczWpWiEmd8MfyJe27sxw,9131
|
|
29
29
|
code_loader/plot_functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
30
|
code_loader/plot_functions/plot_functions.py,sha256=6Q7VWGxetL2W0EK2QeCdObVATvBuHs3YBA09H4uoIk0,14996
|
|
@@ -32,7 +32,7 @@ code_loader/plot_functions/visualize.py,sha256=gsBAYYkwMh7jIpJeDMPS8G4CW-pxwx6Lz
|
|
|
32
32
|
code_loader/utils.py,sha256=YecipkdTA-VcE9F0RQcY9cFnY8P3AksPnHM2Db7xUSk,3972
|
|
33
33
|
code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
34
|
code_loader/visualizers/default_visualizers.py,sha256=onRnLE_TXfgLN4o52hQIOOhUcFexGlqJ3xSpQDVLuZM,2604
|
|
35
|
-
code_loader-1.0.
|
|
36
|
-
code_loader-1.0.
|
|
37
|
-
code_loader-1.0.
|
|
38
|
-
code_loader-1.0.
|
|
35
|
+
code_loader-1.0.184.dev0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
|
36
|
+
code_loader-1.0.184.dev0.dist-info/METADATA,sha256=dFaZKWGS-beIlKCmQI3RON7Iq_EalcHluBC-j9FVL3k,1095
|
|
37
|
+
code_loader-1.0.184.dev0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
38
|
+
code_loader-1.0.184.dev0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|