code-loader 1.0.182.dev0__py3-none-any.whl → 1.0.183.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
code_loader/leaploader.py CHANGED
@@ -44,6 +44,7 @@ class LeapLoader(LeapLoaderBase):
44
44
  super().__init__(code_path, code_entry_name)
45
45
 
46
46
  self._preprocess_result_cached = None
47
+ self._synthetic_responses: Dict[str, PreprocessResponse] = {}
47
48
 
48
49
  try:
49
50
  from code_loader.mixpanel_tracker import track_code_loader_loaded
@@ -189,6 +190,24 @@ class LeapLoader(LeapLoaderBase):
189
190
 
190
191
  def get_sample(self, state: DataStateEnum, sample_id: Union[int, str], instance_id: int = None) -> DatasetSample:
191
192
  self.exec_script()
193
+
194
+ if isinstance(sample_id, str) and sample_id.startswith("synthetic_"):
195
+ last_underscore = sample_id.rfind("_")
196
+ if last_underscore != -1:
197
+ prefix = sample_id[:last_underscore]
198
+ local_idx_str = sample_id[last_underscore + 1:]
199
+ sim_preprocess = self._synthetic_responses.get(prefix)
200
+ if sim_preprocess is not None:
201
+ try:
202
+ local_idx = int(local_idx_str)
203
+ original_sample_id = sim_preprocess.sample_ids[local_idx]
204
+ except (ValueError, IndexError):
205
+ pass
206
+ else:
207
+ return self._get_sample_from_preprocess(
208
+ sim_preprocess, original_sample_id, synthetic_index=sample_id
209
+ )
210
+
192
211
  preprocess_result = self._preprocess_result()
193
212
  if state == DataStateEnum.unlabeled and sample_id not in preprocess_result[state].sample_ids:
194
213
  self._preprocess_result(update_unlabeled_preprocess=True)
@@ -211,6 +230,50 @@ class LeapLoader(LeapLoaderBase):
211
230
  instance_masks=instance_mask)
212
231
  return sample
213
232
 
233
def _get_sample_from_preprocess(
    self,
    preprocess: "PreprocessResponse",
    original_sample_id: Union[int, str],
    synthetic_index: str,
) -> DatasetSample:
    """Assemble a DatasetSample for a synthetic (simulation-produced) sample.

    Every registered input, ground-truth and metadata handler, plus the
    custom latent space handler when one is registered, is invoked with
    ``original_sample_id`` and ``preprocess``. The resulting sample is
    reported under ``synthetic_index`` with the ``additional`` data state.

    Args:
        preprocess: The preprocess response the sample ids belong to.
        original_sample_id: Sample id as it exists inside ``preprocess``.
        synthetic_index: Externally visible id stored as the sample's index.

    Returns:
        A DatasetSample whose ``gt`` is None when the ground-truth mapping
        is empty, and whose ``instance_masks`` is always None.
    """
    container = global_leap_binder.setup_container

    def evaluate(handler):
        # All handlers share the same (sample id, preprocess) call signature.
        return handler.function(original_sample_id, preprocess)

    # Inputs first, then ground truths: same evaluation order as before.
    inputs = {handler.name: evaluate(handler) for handler in container.inputs}
    ground_truths = {handler.name: evaluate(handler) for handler in container.ground_truths}

    metadata = {}
    metadata_is_none = {}
    for handler in container.metadata:
        result = evaluate(handler)
        # A dict-valued handler is flattened to one "<handler>_<key>" entry
        # per item; any other value is stored under the handler's own name.
        if isinstance(result, dict):
            entries = [("{}_{}".format(handler.name, k), v) for k, v in result.items()]
        else:
            entries = [(handler.name, result)]
        for key, value in entries:
            metadata[key], metadata_is_none[key] = self._convert_metadata_to_correct_type(key, value)

    latent_handler = container.custom_latent_space
    custom_latent_space = (
        None if latent_handler is None
        else latent_handler.function(original_sample_id, preprocess)
    )

    return DatasetSample(
        inputs=inputs,
        gt=ground_truths if ground_truths else None,
        metadata=metadata,
        metadata_is_none=metadata_is_none,
        index=synthetic_index,
        state=DataStateEnum.additional,
        custom_latent_space=custom_latent_space,
        instance_masks=None,
    )
276
+
214
277
  def check_dataset(self) -> DatasetIntegParseResult:
215
278
  test_payloads: List[DatasetTestResultPayload] = []
216
279
  setup_response = None
@@ -345,8 +408,8 @@ class LeapLoader(LeapLoaderBase):
345
408
  result_payloads.append(test_result)
346
409
  return result_payloads
347
410
 
348
- def run_simulation(self, sim_name, params=None, n_samples=1, seed=0):
349
- # type: (str, Optional[Dict[str, Any]], int, int) -> Dict[str, npt.NDArray[np.float32]]
411
+ def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_id_prefix=None):
412
+ # type: (str, Optional[Dict[str, Any]], int, int, Optional[str]) -> Dict[str, Any]
350
413
  self.exec_script()
351
414
  sim = next(
352
415
  (s for s in global_leap_binder.setup_container.simulations if s.name == sim_name),
@@ -364,11 +427,18 @@ class LeapLoader(LeapLoaderBase):
364
427
  _simulation_context["active"] = False
365
428
  sim_preprocess.state = DataStateType.additional
366
429
  sim_preprocess.tl_generated = True
430
+ original_sample_ids = list(sim_preprocess.sample_ids)
367
431
  per_encoder = {handler.name: [] for handler in global_leap_binder.setup_container.inputs}
368
- for sample_id in sim_preprocess.sample_ids:
432
+ for sample_id in original_sample_ids:
369
433
  for handler in global_leap_binder.setup_container.inputs:
370
434
  per_encoder[handler.name].append(handler.function(sample_id, sim_preprocess))
371
- return {name: np.stack(arrays) for name, arrays in per_encoder.items()}
435
+ encoded = {name: np.stack(arrays) for name, arrays in per_encoder.items()}
436
+ if sample_id_prefix is not None:
437
+ returned_sample_ids = ["{}_{}".format(sample_id_prefix, i) for i in range(len(original_sample_ids))]
438
+ self._synthetic_responses[sample_id_prefix] = sim_preprocess
439
+ else:
440
+ returned_sample_ids = original_sample_ids
441
+ return {"encoded": encoded, "sample_ids": returned_sample_ids}
372
442
 
373
443
  @staticmethod
374
444
  def _get_all_dataset_base_handlers() -> List[Union[DatasetBaseHandler, MetadataHandler]]:
@@ -154,8 +154,8 @@ class LeapLoaderBase:
154
154
  pass
155
155
 
156
156
  @abstractmethod
157
- def run_simulation(self, sim_name, params=None, n_samples=1, seed=0):
158
- # type: (str, Optional[Dict[str, Any]], int, int) -> Dict[str, Any]
157
+ def run_simulation(self, sim_name, params=None, n_samples=1, seed=0, sample_id_prefix=None):
158
+ # type: (str, Optional[Dict[str, Any]], int, int, Optional[str]) -> Dict[str, Any]
159
159
  pass
160
160
 
161
161
  def is_custom_latent_space(self) -> bool:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.182.dev0
3
+ Version: 1.0.183.dev0
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -23,8 +23,8 @@ code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaS
23
23
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
24
24
  code_loader/inner_leap_binder/leapbinder.py,sha256=XLYYcV50qjMvoC1S6WW0tLBch_0g5gl1UyHiVSWYbvg,40491
25
25
  code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=YlK51b5Ryo396f7tOA7Ole3vYCLs3f5ZLm2qDQ9K1NE,105781
26
- code_loader/leaploader.py,sha256=Md-CsdLiBVfeGIlVqpVjNVF_fEAFbg2ahmVCCPl3iOw,36758
27
- code_loader/leaploaderbase.py,sha256=HrOGX9H8eRbyrFOEdvC2OISJvtZWKMVDah0cry8PXlo,6492
26
+ code_loader/leaploader.py,sha256=mPTi72odXIhPfYpiwGsuDklTBKxnQA9eFhGjsmuGzAU,39942
27
+ code_loader/leaploaderbase.py,sha256=HRmh62g7kVtiD4a8eSB3dTCSnCGz-3Uar0wVItBdhhY,6530
28
28
  code_loader/mixpanel_tracker.py,sha256=rNwRmFifNbdUoqLQvvhhgpKczWpWiEmd8MfyJe27sxw,9131
29
29
  code_loader/plot_functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  code_loader/plot_functions/plot_functions.py,sha256=6Q7VWGxetL2W0EK2QeCdObVATvBuHs3YBA09H4uoIk0,14996
@@ -32,7 +32,7 @@ code_loader/plot_functions/visualize.py,sha256=gsBAYYkwMh7jIpJeDMPS8G4CW-pxwx6Lz
32
32
  code_loader/utils.py,sha256=YecipkdTA-VcE9F0RQcY9cFnY8P3AksPnHM2Db7xUSk,3972
33
33
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
34
  code_loader/visualizers/default_visualizers.py,sha256=onRnLE_TXfgLN4o52hQIOOhUcFexGlqJ3xSpQDVLuZM,2604
35
- code_loader-1.0.182.dev0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
36
- code_loader-1.0.182.dev0.dist-info/METADATA,sha256=qmRxIYq5blbCPlBU8gavN-RNdC8fK3pxS-ckMVME2fI,1095
37
- code_loader-1.0.182.dev0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
38
- code_loader-1.0.182.dev0.dist-info/RECORD,,
35
+ code_loader-1.0.183.dev0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
36
+ code_loader-1.0.183.dev0.dist-info/METADATA,sha256=iN7D5iiqQv7iPN0hABw1u6ALg4FKnFfJTZiYrgp1km0,1095
37
+ code_loader-1.0.183.dev0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
38
+ code_loader-1.0.183.dev0.dist-info/RECORD,,