code-loader 1.0.180.dev1__py3-none-any.whl → 1.0.180.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -42,6 +42,7 @@ class PreprocessResponse:
42
42
  sample_id_type: Optional[Union[Type[str], Type[int]]] = None
43
43
  sample_ids_to_instance_mappings: Optional[Dict[str, List[str]]] = None # in use only for element instance
44
44
  instance_to_sample_ids_mappings: Optional[Dict[str, str]] = None # in use only for element instance
45
+ tl_generated: bool = False
45
46
 
46
47
  def __post_init__(self) -> None:
47
48
  assert self.sample_ids_to_instance_mappings is None, f"Keep sample_ids_to_instance_mappings None when initializing PreprocessResponse"
@@ -254,9 +255,9 @@ class CustomLayerHandler:
254
255
 
255
256
 
256
257
  @dataclass
257
- class SimulationInstance:
258
+ class SimulationHandler:
258
259
  name: str
259
- function: Callable[..., Any]
260
+ function: Callable[..., PreprocessResponse]
260
261
  sim_config: SimConfig
261
262
 
262
263
 
@@ -275,7 +276,7 @@ class DatasetIntegrationSetup:
275
276
  instance_metrics: List[InstanceMetricHandler] = field(default_factory=list)
276
277
  custom_layers: Dict[str, CustomLayerHandler] = field(default_factory=dict)
277
278
  custom_latent_space: Optional[CustomLatentSpaceHandler] = None
278
- simulations: List[SimulationInstance] = field(default_factory=list)
279
+ simulations: List[SimulationHandler] = field(default_factory=list)
279
280
 
280
281
 
281
282
  @dataclass
@@ -70,7 +70,7 @@ class PredictionTypeInstance:
70
70
 
71
71
 
72
72
  @dataclass
73
- class SimulationSetupInstance:
73
+ class SimulationInstance:
74
74
  name: str
75
75
  sim_config: Dict[str, Any]
76
76
 
@@ -85,7 +85,7 @@ class DatasetSetup:
85
85
  prediction_types: List[PredictionTypeInstance]
86
86
  custom_losses: List[CustomLossInstance]
87
87
  metrics: List[MetricInstance] = field(default_factory=list)
88
- simulations: List[SimulationSetupInstance] = field(default_factory=list)
88
+ simulations: List[SimulationInstance] = field(default_factory=list)
89
89
 
90
90
 
91
91
  @dataclass
@@ -13,7 +13,7 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
13
13
  CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
14
14
  RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData, CustomLossHandlerData, SamplePreprocessResponse, \
15
15
  ElementInstanceMasksHandler, InstanceCallableInterface, CustomLatentSpaceHandler, InstanceMetricHandler, \
16
- SimulationInstance
16
+ SimulationHandler
17
17
  from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection, DatasetMetadataType
18
18
  from code_loader.contract.mapping import NodeConnection, NodeMapping, NodeMappingType
19
19
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload, LeapAnalysisConfiguration
@@ -221,7 +221,7 @@ class LeapBinder:
221
221
  """
222
222
  self.setup_container.unlabeled_data_preprocess = UnlabeledDataPreprocessHandler(function)
223
223
 
224
- def set_simulation(self, function: Callable[..., Any], name: str, sim_config_raw: Dict[str, Any]) -> None:
224
+ def set_simulation(self, function: Callable[..., PreprocessResponse], name: str, sim_config_raw: Dict[str, Any]) -> None:
225
225
  for sim in self.setup_container.simulations:
226
226
  if sim.name == name:
227
227
  raise Exception(
@@ -229,7 +229,7 @@ class LeapBinder:
229
229
  )
230
230
  sim_config = parse_sim_config(sim_config_raw)
231
231
  self.setup_container.simulations.append(
232
- SimulationInstance(name=name, function=function, sim_config=sim_config)
232
+ SimulationHandler(name=name, function=function, sim_config=sim_config)
233
233
  )
234
234
 
235
235
  def set_input(self, function: SectionCallableInterface, name: str, channel_dim: int = -1) -> None:
@@ -1400,7 +1400,7 @@ def tensorleap_preprocess():
1400
1400
 
1401
1401
 
1402
1402
  def tensorleap_simulation(name: str, sim_params: dict):
1403
- def decorating_function(user_function: Callable):
1403
+ def decorating_function(user_function: Callable[..., PreprocessResponse]):
1404
1404
  sig = inspect.signature(user_function)
1405
1405
  func_params = set(sig.parameters.keys())
1406
1406
  expected_params = set(sim_params.keys()) | {"N"}
@@ -1423,12 +1423,47 @@ def tensorleap_simulation(name: str, sim_params: dict):
1423
1423
 
1424
1424
  leap_binder.set_simulation(user_function, name, sim_params)
1425
1425
 
1426
- def inner(*args, **kwargs):
1427
- result = user_function(*args, **kwargs)
1426
+ if not _call_from_tl_platform:
1427
+ add_table_row(f"tensorleap_simulation:{name} (optional)")
1428
+
1429
+ def _validate_input_args(*args, **kwargs):
1430
+ assert len(args) == 0, (
1431
+ f"{user_function.__name__}() validation failed: "
1432
+ f"Simulation functions must be called with keyword arguments only. Got positional args: {args}."
1433
+ )
1434
+ missing = expected_params - set(kwargs.keys())
1435
+ assert not missing, (
1436
+ f"{user_function.__name__}() validation failed: "
1437
+ f"Missing required keyword arguments: {missing}."
1438
+ )
1439
+
1440
+ def _validate_result(result):
1428
1441
  assert isinstance(result, PreprocessResponse), (
1429
1442
  f"{user_function.__name__}() validation failed: "
1430
1443
  f"Expected return type PreprocessResponse. Got {type(result).__name__}."
1431
1444
  )
1445
+ assert result.state == DataStateType.additional, (
1446
+ f"{user_function.__name__}() validation failed: "
1447
+ f"Simulation must return a PreprocessResponse with state=DataStateType.additional. "
1448
+ f"Got state={result.state!r}."
1449
+ )
1450
+
1451
+ def inner(*args, **kwargs):
1452
+ if not _call_from_tl_platform:
1453
+ set_current('tensorleap_simulation')
1454
+ _validate_input_args(*args, **kwargs)
1455
+ result = user_function(*args, **kwargs)
1456
+ _validate_result(result)
1457
+ result.tl_generated = True
1458
+ if not _call_from_tl_platform:
1459
+ update_env_params_func(f"tensorleap_simulation:{name}", "v")
1460
+ try:
1461
+ emit_integration_event_once(AnalyticsEvent.SIMULATION_INTEGRATION_TEST, {
1462
+ 'simulation_name': name,
1463
+ 'sim_params_count': len(sim_params)
1464
+ })
1465
+ except Exception as e:
1466
+ logger.debug(f"Failed to emit simulation integration test event: {e}")
1432
1467
  return result
1433
1468
 
1434
1469
  return inner
@@ -2129,14 +2164,17 @@ def tensorleap_status_table():
2129
2164
  traceback.print_exception(exc_type, exc_value, exc_traceback)
2130
2165
  run_on_exit()
2131
2166
 
2167
+ def add_table_row(name: str):
2168
+ table.append({"name": name, "Added to integration": UNKNOWN})
2169
+
2132
2170
  atexit.register(run_on_exit)
2133
2171
  sys.excepthook = handle_exception
2134
2172
 
2135
- return set_current, update_env_params
2173
+ return set_current, update_env_params, add_table_row
2136
2174
 
2137
2175
 
2138
2176
  if not _call_from_tl_platform:
2139
- set_current, update_env_params_func = tensorleap_status_table()
2177
+ set_current, update_env_params_func, add_table_row = tensorleap_status_table()
2140
2178
 
2141
2179
 
2142
2180
 
code_loader/leaploader.py CHANGED
@@ -22,7 +22,7 @@ from code_loader.contract.exceptions import DatasetScriptException
22
22
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
23
23
  DatasetPreprocess, DatasetSetup, DatasetInputInstance, DatasetOutputInstance, DatasetMetadataInstance, \
24
24
  VisualizerInstance, PredictionTypeInstance, ModelSetup, CustomLayerInstance, MetricInstance, CustomLossInstance, \
25
- EngineFileContract, SimulationSetupInstance
25
+ EngineFileContract, SimulationInstance
26
26
  from code_loader.contract.sim_config import FloatBounds, IntBounds, CategoricalBounds
27
27
  from code_loader.inner_leap_binder import global_leap_binder
28
28
  from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
@@ -472,7 +472,7 @@ class LeapLoader(LeapLoaderBase):
472
472
  }
473
473
  for name, param in sim.sim_config.items()
474
474
  }
475
- simulations.append(SimulationSetupInstance(name=sim.name, sim_config=sim_config_serialized))
475
+ simulations.append(SimulationInstance(name=sim.name, sim_config=sim_config_serialized))
476
476
 
477
477
  return DatasetSetup(preprocess=dataset_preprocess, inputs=inputs, outputs=ground_truths,
478
478
  metadata=metadata_instances, visualizers=visualizers, prediction_types=prediction_types,
@@ -23,6 +23,7 @@ class AnalyticsEvent(str, Enum):
23
23
  PREPROCESS_INTEGRATION_TEST = "preprocess_integration_test"
24
24
  INPUT_ENCODER_INTEGRATION_TEST = "input_encoder_integration_test"
25
25
  GT_ENCODER_INTEGRATION_TEST = "gt_encoder_integration_test"
26
+ SIMULATION_INTEGRATION_TEST = "simulation_integration_test"
26
27
 
27
28
 
28
29
  class CodeLoaderLoadedProps(TypedDict, total=False):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.180.dev1
3
+ Version: 1.0.180.dev3
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -1,11 +1,11 @@
1
1
  LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=outxRQ0M-zMfV0QGVJmAed5qWfRmyD0TV6-goEGAzBw,406
3
3
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- code_loader/contract/datasetclasses.py,sha256=hMovDpEEl-07lPKV8fp52HYd0yK6trkta7oTZ8cKU6I,10277
4
+ code_loader/contract/datasetclasses.py,sha256=KefWGK04SDGQe4KjwBCnP14LO7sO7esS5vYmevFZW_A,10321
5
5
  code_loader/contract/enums.py,sha256=2q-IV_5g9lLE306DIbWA1c0tn5IhDtxsKxyV1x_Lreg,1671
6
6
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
7
7
  code_loader/contract/mapping.py,sha256=sWJhpng-IkOzQnWQdMT5w2ZZ3X1Z_OOzSwCLXIS7oxE,1446
8
- code_loader/contract/responsedataclasses.py,sha256=G37OBHmb-b9-0gc5Hs69t8jjjRtEzmWP389Y9IWWlGQ,4935
8
+ code_loader/contract/responsedataclasses.py,sha256=5VFgGjRubMW8ItMPils3rkBNejunCGLaa192AIi-xko,4925
9
9
  code_loader/contract/sim_config.py,sha256=tWgM3n2_UTP8HnWBUFYLY928ya1ezFWhMOOU9c3WsHU,3002
10
10
  code_loader/contract/visualizer_classes.py,sha256=Wz9eItmoRaKEHa3p0aW0Ypxx4_xUmaZyLBznnTuxwi0,15425
11
11
  code_loader/default_losses.py,sha256=NoOQym1106bDN5dcIk56Elr7ZG5quUHArqfP5-Nyxyo,1139
@@ -21,18 +21,18 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
21
21
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
22
22
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
23
23
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
24
- code_loader/inner_leap_binder/leapbinder.py,sha256=07_ZIjRLKuERF12fo3IohMPNjfDUeZQxjk1vJMGtd1A,38484
25
- code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=DTvxzqe-XI4oDS62ck7MNeBmt0-FaGbm6Rpva-idd1U,102299
26
- code_loader/leaploader.py,sha256=-AnZ_-_FexkW8VRyC2njK-aoE_x0Ud2RRxeZL4OHt5A,33032
24
+ code_loader/inner_leap_binder/leapbinder.py,sha256=Avt4lQv4mehoEOlH7UyF7Wna_GGJWSAF6r55qUk1MYE,38497
25
+ code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=QoWaNdHdh3XQDYXRqvnSS0S3FzKWB-6jbsE6Mq6hl0U,104111
26
+ code_loader/leaploader.py,sha256=vz812i2a0-9UTqsxNx9z0lqJlLt4BWGsPwl4E4lWjZ4,33022
27
27
  code_loader/leaploaderbase.py,sha256=NJGaas8S6JeHYSsKkMSutyfcSKdK9jXTic7BcjC5uNc,6303
28
- code_loader/mixpanel_tracker.py,sha256=U7eUGrPjc-2rgFG7isqosf65tKZkotQ0XKuAML_lIjA,9067
28
+ code_loader/mixpanel_tracker.py,sha256=rNwRmFifNbdUoqLQvvhhgpKczWpWiEmd8MfyJe27sxw,9131
29
29
  code_loader/plot_functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  code_loader/plot_functions/plot_functions.py,sha256=6Q7VWGxetL2W0EK2QeCdObVATvBuHs3YBA09H4uoIk0,14996
31
31
  code_loader/plot_functions/visualize.py,sha256=gsBAYYkwMh7jIpJeDMPS8G4CW-pxwx6LznoQIvi4vpo,657
32
32
  code_loader/utils.py,sha256=YecipkdTA-VcE9F0RQcY9cFnY8P3AksPnHM2Db7xUSk,3972
33
33
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
34
  code_loader/visualizers/default_visualizers.py,sha256=onRnLE_TXfgLN4o52hQIOOhUcFexGlqJ3xSpQDVLuZM,2604
35
- code_loader-1.0.180.dev1.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
36
- code_loader-1.0.180.dev1.dist-info/METADATA,sha256=yCrmwdjI8kBVFXYWbtmaoc0zwyo1Xh8MDwc7Vp1a--c,1095
37
- code_loader-1.0.180.dev1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
38
- code_loader-1.0.180.dev1.dist-info/RECORD,,
35
+ code_loader-1.0.180.dev3.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
36
+ code_loader-1.0.180.dev3.dist-info/METADATA,sha256=e742DymYsZVynEaCgd7V8Fsx4hyNnJJO7aM0mfB55Ls,1095
37
+ code_loader-1.0.180.dev3.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
38
+ code_loader-1.0.180.dev3.dist-info/RECORD,,