code-loader 1.0.182.dev0__py3-none-any.whl → 1.0.183.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_loader/leaploader.py +74 -4
- code_loader/leaploaderbase.py +2 -2
- {code_loader-1.0.182.dev0.dist-info → code_loader-1.0.183.dev0.dist-info}/METADATA +1 -1
- {code_loader-1.0.182.dev0.dist-info → code_loader-1.0.183.dev0.dist-info}/RECORD +6 -6
- {code_loader-1.0.182.dev0.dist-info → code_loader-1.0.183.dev0.dist-info}/LICENSE +0 -0
- {code_loader-1.0.182.dev0.dist-info → code_loader-1.0.183.dev0.dist-info}/WHEEL +0 -0
code_loader/leaploader.py
CHANGED
|
@@ -44,6 +44,7 @@ class LeapLoader(LeapLoaderBase):
|
|
|
44
44
|
super().__init__(code_path, code_entry_name)
|
|
45
45
|
|
|
46
46
|
self._preprocess_result_cached = None
|
|
47
|
+
self._synthetic_responses: Dict[str, PreprocessResponse] = {}
|
|
47
48
|
|
|
48
49
|
try:
|
|
49
50
|
from code_loader.mixpanel_tracker import track_code_loader_loaded
|
|
@@ -189,6 +190,24 @@ class LeapLoader(LeapLoaderBase):
|
|
|
189
190
|
|
|
190
191
|
def get_sample(self, state: DataStateEnum, sample_id: Union[int, str], instance_id: int = None) -> DatasetSample:
|
|
191
192
|
self.exec_script()
|
|
193
|
+
|
|
194
|
+
if isinstance(sample_id, str) and sample_id.startswith("synthetic_"):
|
|
195
|
+
last_underscore = sample_id.rfind("_")
|
|
196
|
+
if last_underscore != -1:
|
|
197
|
+
prefix = sample_id[:last_underscore]
|
|
198
|
+
local_idx_str = sample_id[last_underscore + 1:]
|
|
199
|
+
sim_preprocess = self._synthetic_responses.get(prefix)
|
|
200
|
+
if sim_preprocess is not None:
|
|
201
|
+
try:
|
|
202
|
+
local_idx = int(local_idx_str)
|
|
203
|
+
original_sample_id = sim_preprocess.sample_ids[local_idx]
|
|
204
|
+
except (ValueError, IndexError):
|
|
205
|
+
pass
|
|
206
|
+
else:
|
|
207
|
+
return self._get_sample_from_preprocess(
|
|
208
|
+
sim_preprocess, original_sample_id, synthetic_index=sample_id
|
|
209
|
+
)
|
|
210
|
+
|
|
192
211
|
preprocess_result = self._preprocess_result()
|
|
193
212
|
if state == DataStateEnum.unlabeled and sample_id not in preprocess_result[state].sample_ids:
|
|
194
213
|
self._preprocess_result(update_unlabeled_preprocess=True)
|
|
@@ -211,6 +230,50 @@ class LeapLoader(LeapLoaderBase):
|
|
|
211
230
|
instance_masks=instance_mask)
|
|
212
231
|
return sample
|
|
213
232
|
|
|
233
|
+
def _get_sample_from_preprocess(
|
|
234
|
+
self,
|
|
235
|
+
preprocess: "PreprocessResponse",
|
|
236
|
+
original_sample_id: Union[int, str],
|
|
237
|
+
synthetic_index: str,
|
|
238
|
+
) -> DatasetSample:
|
|
239
|
+
inputs = {}
|
|
240
|
+
for handler in global_leap_binder.setup_container.inputs:
|
|
241
|
+
inputs[handler.name] = handler.function(original_sample_id, preprocess)
|
|
242
|
+
|
|
243
|
+
gt = {}
|
|
244
|
+
for handler in global_leap_binder.setup_container.ground_truths:
|
|
245
|
+
gt[handler.name] = handler.function(original_sample_id, preprocess)
|
|
246
|
+
|
|
247
|
+
metadata = {}
|
|
248
|
+
metadata_is_none = {}
|
|
249
|
+
for handler in global_leap_binder.setup_container.metadata:
|
|
250
|
+
handler_result = handler.function(original_sample_id, preprocess)
|
|
251
|
+
if isinstance(handler_result, dict):
|
|
252
|
+
for k, v in handler_result.items():
|
|
253
|
+
key = "{}_{}".format(handler.name, k)
|
|
254
|
+
metadata[key], metadata_is_none[key] = self._convert_metadata_to_correct_type(key, v)
|
|
255
|
+
else:
|
|
256
|
+
metadata[handler.name], metadata_is_none[handler.name] = self._convert_metadata_to_correct_type(
|
|
257
|
+
handler.name, handler_result
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
custom_latent_space = None
|
|
261
|
+
if global_leap_binder.setup_container.custom_latent_space is not None:
|
|
262
|
+
custom_latent_space = global_leap_binder.setup_container.custom_latent_space.function(
|
|
263
|
+
original_sample_id, preprocess
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
return DatasetSample(
|
|
267
|
+
inputs=inputs,
|
|
268
|
+
gt=gt if gt else None,
|
|
269
|
+
metadata=metadata,
|
|
270
|
+
metadata_is_none=metadata_is_none,
|
|
271
|
+
index=synthetic_index,
|
|
272
|
+
state=DataStateEnum.additional,
|
|
273
|
+
custom_latent_space=custom_latent_space,
|
|
274
|
+
instance_masks=None,
|
|
275
|
+
)
|
|
276
|
+
|
|
214
277
|
def check_dataset(self) -> DatasetIntegParseResult:
|
|
215
278
|
test_payloads: List[DatasetTestResultPayload] = []
|
|
216
279
|
setup_response = None
|
|
@@ -345,8 +408,8 @@ class LeapLoader(LeapLoaderBase):
|
|
|
345
408
|
result_payloads.append(test_result)
|
|
346
409
|
return result_payloads
|
|
347
410
|
|
|
348
|
-
def run_simulation(self, sim_name, params=None, n_samples=1, seed=0):
|
|
349
|
-
# type: (str, Optional[Dict[str, Any]], int, int) -> Dict[str, Any]
|
|
411
|
+
def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_id_prefix=None):
|
|
412
|
+
# type: (str, Optional[Dict[str, Any]], int, int, Optional[str]) -> Dict[str, Any]
|
|
350
413
|
self.exec_script()
|
|
351
414
|
sim = next(
|
|
352
415
|
(s for s in global_leap_binder.setup_container.simulations if s.name == sim_name),
|
|
@@ -364,11 +427,18 @@ class LeapLoader(LeapLoaderBase):
|
|
|
364
427
|
_simulation_context["active"] = False
|
|
365
428
|
sim_preprocess.state = DataStateType.additional
|
|
366
429
|
sim_preprocess.tl_generated = True
|
|
430
|
+
original_sample_ids = list(sim_preprocess.sample_ids)
|
|
367
431
|
per_encoder = {handler.name: [] for handler in global_leap_binder.setup_container.inputs}
|
|
368
|
-
for sample_id in sim_preprocess.sample_ids:
|
|
432
|
+
for sample_id in original_sample_ids:
|
|
369
433
|
for handler in global_leap_binder.setup_container.inputs:
|
|
370
434
|
per_encoder[handler.name].append(handler.function(sample_id, sim_preprocess))
|
|
371
|
-
|
|
435
|
+
encoded = {name: np.stack(arrays) for name, arrays in per_encoder.items()}
|
|
436
|
+
if sample_id_prefix is not None:
|
|
437
|
+
returned_sample_ids = ["{}_{}".format(sample_id_prefix, i) for i in range(len(original_sample_ids))]
|
|
438
|
+
self._synthetic_responses[sample_id_prefix] = sim_preprocess
|
|
439
|
+
else:
|
|
440
|
+
returned_sample_ids = original_sample_ids
|
|
441
|
+
return {"encoded": encoded, "sample_ids": returned_sample_ids}
|
|
372
442
|
|
|
373
443
|
@staticmethod
|
|
374
444
|
def _get_all_dataset_base_handlers() -> List[Union[DatasetBaseHandler, MetadataHandler]]:
|
code_loader/leaploaderbase.py
CHANGED
|
@@ -154,8 +154,8 @@ class LeapLoaderBase:
|
|
|
154
154
|
pass
|
|
155
155
|
|
|
156
156
|
@abstractmethod
|
|
157
|
-
def run_simulation(self, sim_name, params=None, n_samples=1, seed=0):
|
|
158
|
-
# type: (str, Optional[Dict[str, Any]], int, int) -> Dict[str, Any]
|
|
157
|
+
def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_id_prefix=None):
|
|
158
|
+
# type: (str, Optional[Dict[str, Any]], int, int, Optional[str]) -> Dict[str, Any]
|
|
159
159
|
pass
|
|
160
160
|
|
|
161
161
|
def is_custom_latent_space(self) -> bool:
|
|
@@ -23,8 +23,8 @@ code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaS
|
|
|
23
23
|
code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
|
|
24
24
|
code_loader/inner_leap_binder/leapbinder.py,sha256=XLYYcV50qjMvoC1S6WW0tLBch_0g5gl1UyHiVSWYbvg,40491
|
|
25
25
|
code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=YlK51b5Ryo396f7tOA7Ole3vYCLs3f5ZLm2qDQ9K1NE,105781
|
|
26
|
-
code_loader/leaploader.py,sha256=
|
|
27
|
-
code_loader/leaploaderbase.py,sha256=
|
|
26
|
+
code_loader/leaploader.py,sha256=mPTi72odXIhPfYpiwGsuDklTBKxnQA9eFhGjsmuGzAU,39942
|
|
27
|
+
code_loader/leaploaderbase.py,sha256=HRmh62g7kVtiD4a8eSB3dTCSnCGz-3Uar0wVItBdhhY,6530
|
|
28
28
|
code_loader/mixpanel_tracker.py,sha256=rNwRmFifNbdUoqLQvvhhgpKczWpWiEmd8MfyJe27sxw,9131
|
|
29
29
|
code_loader/plot_functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
30
|
code_loader/plot_functions/plot_functions.py,sha256=6Q7VWGxetL2W0EK2QeCdObVATvBuHs3YBA09H4uoIk0,14996
|
|
@@ -32,7 +32,7 @@ code_loader/plot_functions/visualize.py,sha256=gsBAYYkwMh7jIpJeDMPS8G4CW-pxwx6Lz
|
|
|
32
32
|
code_loader/utils.py,sha256=YecipkdTA-VcE9F0RQcY9cFnY8P3AksPnHM2Db7xUSk,3972
|
|
33
33
|
code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
34
|
code_loader/visualizers/default_visualizers.py,sha256=onRnLE_TXfgLN4o52hQIOOhUcFexGlqJ3xSpQDVLuZM,2604
|
|
35
|
-
code_loader-1.0.182.dev0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
|
36
|
-
code_loader-1.0.
|
|
37
|
-
code_loader-1.0.182.dev0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
38
|
-
code_loader-1.0.182.dev0.dist-info/RECORD,,
|
|
35
|
+
code_loader-1.0.183.dev0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
|
36
|
+
code_loader-1.0.183.dev0.dist-info/METADATA,sha256=iN7D5iiqQv7iPN0hABw1u6ALg4FKnFfJTZiYrgp1km0,1095
|
|
37
|
+
code_loader-1.0.183.dev0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
38
|
+
code_loader-1.0.183.dev0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|