code-loader 1.0.179.dev2__py3-none-any.whl → 1.0.180.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,6 +9,7 @@ from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataTyp
9
9
  MetricDirection, DatasetMetadataType
10
10
  from code_loader.contract.visualizer_classes import LeapImage, LeapText, LeapGraph, LeapHorizontalBar, \
11
11
  LeapTextMask, LeapImageMask, LeapImageWithBBox, LeapImageWithHeatmap, LeapVideo
12
+ from code_loader.contract.sim_config import SimConfig
12
13
 
13
14
  custom_latent_space_attribute = "custom_latent_space"
14
15
 
@@ -252,6 +253,13 @@ class CustomLayerHandler:
252
253
  use_custom_latent_space: bool = False
253
254
 
254
255
 
256
@dataclass
class SimulationInstance:
    """A simulation function registered on the binder.

    Bundles the user's simulation callable with the name it was registered
    under and its already-parsed (typed) parameter configuration.
    """

    # Registration name; LeapBinder.set_simulation rejects duplicates.
    name: str
    # The user-supplied simulation callable.
    function: Callable[..., Any]
    # Parsed parameter config: parameter name -> SimParamConfig.
    sim_config: SimConfig
+
262
+
255
263
  @dataclass
256
264
  class DatasetIntegrationSetup:
257
265
  preprocess: Optional[PreprocessHandler] = None
@@ -267,6 +275,7 @@ class DatasetIntegrationSetup:
267
275
  instance_metrics: List[InstanceMetricHandler] = field(default_factory=list)
268
276
  custom_layers: Dict[str, CustomLayerHandler] = field(default_factory=dict)
269
277
  custom_latent_space: Optional[CustomLatentSpaceHandler] = None
278
+ simulations: List[SimulationInstance] = field(default_factory=list)
270
279
 
271
280
 
272
281
  @dataclass
@@ -69,6 +69,12 @@ class PredictionTypeInstance:
69
69
  channel_dim: int
70
70
 
71
71
 
72
@dataclass
class SimulationSetupInstance:
    """Serialized description of one registered simulation in the dataset setup."""

    # Name the simulation was registered under.
    name: str
    # Plain-dict parameter config: parameter name -> {"type": ..., "bounds": {...}}.
    sim_config: Dict[str, Any]
+
77
+
72
78
  @dataclass
73
79
  class DatasetSetup:
74
80
  preprocess: DatasetPreprocess
@@ -79,6 +85,7 @@ class DatasetSetup:
79
85
  prediction_types: List[PredictionTypeInstance]
80
86
  custom_losses: List[CustomLossInstance]
81
87
  metrics: List[MetricInstance] = field(default_factory=list)
88
+ simulations: List[SimulationSetupInstance] = field(default_factory=list)
82
89
 
83
90
 
84
91
  @dataclass
@@ -0,0 +1,92 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Union
3
+
4
+ from code_loader.contract.enums import DatasetMetadataType
5
+
6
+
7
@dataclass
class FloatBounds:
    """Numeric range for a float simulation parameter."""

    # Lower bound.
    min: float
    # Upper bound.
    max: float


@dataclass
class IntBounds:
    """Numeric range for an int simulation parameter."""

    # Lower bound.
    min: int
    # Upper bound.
    max: int


@dataclass
class CategoricalBounds:
    """Allowed values for a string (categorical) simulation parameter."""

    # The admissible category values.
    values: List[str]
22
+
23
+
24
@dataclass
class SimParamConfig:
    """Validated configuration of a single simulation parameter."""

    # Resolved parameter type (looked up through PARAM_TYPE_MAP by parse_sim_config).
    type: DatasetMetadataType
    # FloatBounds / IntBounds for numeric types, CategoricalBounds for strings.
    bounds: Union[FloatBounds, IntBounds, CategoricalBounds]


# Typed simulation config: parameter name -> its validated configuration.
SimConfig = Dict[str, SimParamConfig]

# Accepted spellings of a parameter's "type": the Python type object itself,
# or its name as a string ("str" and "string" both mean a string parameter).
PARAM_TYPE_MAP = {
    # Python type objects.
    float: DatasetMetadataType.float,
    int: DatasetMetadataType.int,
    str: DatasetMetadataType.string,
    # String names.
    "float": DatasetMetadataType.float,
    "int": DatasetMetadataType.int,
    "str": DatasetMetadataType.string,
    "string": DatasetMetadataType.string,
}
41
+
42
+
43
def _parse_bounds(name: str, metadata_type: DatasetMetadataType,
                  bounds_raw: Dict[str, Any]) -> Union[FloatBounds, IntBounds, CategoricalBounds]:
    """Validate the raw ``bounds`` of one simulation parameter and build a typed bounds object.

    Args:
        name: Parameter name; used only to make error messages actionable.
        metadata_type: Already-resolved type of the parameter.
        bounds_raw: ``{"min": ..., "max": ...}`` for float/int parameters, or
            ``{"values": [...]}`` for string parameters.

    Returns:
        ``FloatBounds`` / ``IntBounds`` for numeric types, ``CategoricalBounds``
        for string types.

    Raises:
        ValueError: if ``bounds_raw`` is not a mapping, required keys are
            missing, numeric bounds are not numbers or do not satisfy
            ``min < max`` after conversion, ``values`` is a bare string, not a
            collection, or empty, or ``metadata_type`` is unsupported.
    """
    # Local import keeps this block independent of the module import list.
    from collections.abc import Mapping

    if not isinstance(bounds_raw, Mapping):
        raise ValueError(
            f"Parameter '{name}' bounds must be a dict, got {type(bounds_raw).__name__}."
        )

    if metadata_type in (DatasetMetadataType.float, DatasetMetadataType.int):
        if "min" not in bounds_raw or "max" not in bounds_raw:
            raise ValueError(
                f"Parameter '{name}' bounds must have 'min' and 'max' for numeric type."
            )
        is_float = metadata_type == DatasetMetadataType.float
        bounds_cls = FloatBounds if is_float else IntBounds
        convert = float if is_float else int
        try:
            bounds = bounds_cls(min=convert(bounds_raw["min"]), max=convert(bounds_raw["max"]))
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Parameter '{name}': 'min' and 'max' must be numbers, "
                f"got {bounds_raw['min']!r} and {bounds_raw['max']!r}."
            ) from err
        # Validate AFTER conversion. Reasons:
        # - int() truncates (0.2..0.8 would become 0..0).
        # - Raw strings compare lexicographically ("2" >= "10").
        # `not (a < b)` also rejects NaN, for which `a >= b` is False.
        if not (bounds.min < bounds.max):
            raise ValueError(
                f"Parameter '{name}': min ({bounds.min}) must be < max ({bounds.max})."
            )
        return bounds

    if metadata_type == DatasetMetadataType.string:
        if "values" not in bounds_raw:
            raise ValueError(
                f"Parameter '{name}' bounds must have 'values' for string type."
            )
        raw_values = bounds_raw["values"]
        # A bare string is iterable and would otherwise be split into characters.
        if isinstance(raw_values, (str, bytes)):
            raise ValueError(
                f"Parameter '{name}' 'values' must be a list of strings, not a single string."
            )
        try:
            values = list(raw_values)
        except TypeError as err:
            raise ValueError(
                f"Parameter '{name}' 'values' must be a list of strings, "
                f"got {type(raw_values).__name__}."
            ) from err
        # Check emptiness after materializing: a generator is truthy even when empty.
        if not values:
            raise ValueError(f"Parameter '{name}' 'values' must not be empty.")
        return CategoricalBounds(values=values)

    raise ValueError(f"Parameter '{name}' has unsupported metadata type: {metadata_type}.")
67
+
68
+
69
def parse_sim_config(raw: Dict[str, Any]) -> SimConfig:
    """Validate a raw simulation config and convert it into a typed ``SimConfig``.

    ``raw`` maps each parameter name to ``{"type": <type>, "bounds": {...}}``.
    ``<type>`` must be a key of ``PARAM_TYPE_MAP``: ``float``/``int``/``str``,
    or the strings "float", "int", "str" or "string". Bounds rules are
    enforced by ``_parse_bounds``.

    Raises:
        ValueError: if ``raw`` is empty or not a mapping, or any parameter entry
            is not a mapping, lacks ``type``/``bounds``, has an unsupported type,
            or has invalid bounds.
    """
    # Local import keeps this block independent of the module import list.
    from collections.abc import Mapping

    if not raw:
        raise ValueError("sim_config must have at least one parameter.")
    if not isinstance(raw, Mapping):
        raise ValueError(f"sim_config must be a dict, got {type(raw).__name__}.")

    result: SimConfig = {}
    for name, param in raw.items():
        if not isinstance(param, Mapping):
            raise ValueError(
                f"Parameter '{name}' must be a dict with 'type' and 'bounds', "
                f"got {type(param).__name__}."
            )
        if "type" not in param:
            raise ValueError(f"Parameter '{name}' missing 'type'.")
        try:
            metadata_type = PARAM_TYPE_MAP[param["type"]]
        except (KeyError, TypeError):
            # TypeError covers unhashable "type" values such as lists or dicts.
            raise ValueError(
                f"Parameter '{name}' has unsupported type '{param['type']}'. "
                "Supported: float, int, str (or string)."
            ) from None

        if "bounds" not in param:
            raise ValueError(f"Parameter '{name}' missing 'bounds'.")

        result[name] = SimParamConfig(
            type=metadata_type,
            bounds=_parse_bounds(name, metadata_type, param["bounds"]),
        )

    return result
@@ -12,7 +12,8 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
12
12
  CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
13
13
  CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
14
14
  RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData, CustomLossHandlerData, SamplePreprocessResponse, \
15
- ElementInstanceMasksHandler, InstanceCallableInterface, CustomLatentSpaceHandler, InstanceMetricHandler
15
+ ElementInstanceMasksHandler, InstanceCallableInterface, CustomLatentSpaceHandler, InstanceMetricHandler, \
16
+ SimulationInstance
16
17
  from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection, DatasetMetadataType
17
18
  from code_loader.contract.mapping import NodeConnection, NodeMapping, NodeMappingType
18
19
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload, LeapAnalysisConfiguration
@@ -20,6 +21,7 @@ from code_loader.contract.visualizer_classes import map_leap_data_type_to_visual
20
21
  from code_loader.default_losses import loss_name_to_function
21
22
  from code_loader.default_metrics import metrics_names_to_functions_and_direction
22
23
  from code_loader.utils import to_numpy_return_wrapper, get_shape, to_numpy_return_masks_wrapper
24
+ from code_loader.contract.sim_config import parse_sim_config
23
25
  from code_loader.visualizers.default_visualizers import DefaultVisualizer, \
24
26
  default_graph_visualizer, \
25
27
  default_image_visualizer, default_horizontal_bar_visualizer, default_word_visualizer, \
@@ -219,6 +221,17 @@ class LeapBinder:
219
221
  """
220
222
  self.setup_container.unlabeled_data_preprocess = UnlabeledDataPreprocessHandler(function)
221
223
 
224
def set_simulation(self, function: Callable[..., Any], name: str, sim_config_raw: Dict[str, Any]) -> None:
    """Register ``function`` as a simulation under a unique ``name``.

    The raw config is validated and converted with ``parse_sim_config``. The
    resulting ``SimulationInstance`` is then appended to
    ``self.setup_container.simulations``.

    Raises:
        Exception: if a simulation called ``name`` is already registered.
        ValueError: propagated from ``parse_sim_config`` for an invalid config.
    """
    registered = self.setup_container.simulations
    if any(existing.name == name for existing in registered):
        raise Exception(f"Simulation with name '{name}' already exists. Please choose another.")

    parsed_config = parse_sim_config(sim_config_raw)
    registered.append(SimulationInstance(name=name, function=function, sim_config=parsed_config))
234
+
222
235
  def set_input(self, function: SectionCallableInterface, name: str, channel_dim: int = -1) -> None:
223
236
  """
224
237
  Set the input handler function.
@@ -1,4 +1,5 @@
1
1
  # mypy: ignore-errors
2
+ import inspect
2
3
  import os
3
4
  import warnings
4
5
  import logging
@@ -28,10 +29,6 @@ from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, Le
28
29
  from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
29
30
  from code_loader.mixpanel_tracker import clear_integration_events, AnalyticsEvent, emit_integration_event_once
30
31
 
31
- import inspect
32
- import functools
33
- from pathlib import Path
34
-
35
32
  _called_from_inside_tl_decorator = 0
36
33
  _called_from_inside_tl_integration_test_decorator = False
37
34
  _call_from_tl_platform = os.environ.get('IS_TENSORLEAP_PLATFORM') == 'true'
@@ -1233,8 +1230,9 @@ def tensorleap_metadata(
1233
1230
  def _validate_result(result):
1234
1231
  supported_result_types = (type(None), int, str, bool, float, dict, np.floating,
1235
1232
  np.bool_, np.unsignedinteger, np.signedinteger, np.integer)
1236
- validate_output_structure(result, func_name=user_function.__name__,
1237
- expected_type_name=supported_result_types)
1233
+ if isinstance(result, tuple):
1234
+ validate_output_structure(result, func_name=user_function.__name__,
1235
+ expected_type_name=supported_result_types)
1238
1236
  assert isinstance(result, supported_result_types), \
1239
1237
  (f'{user_function.__name__}() validation failed: '
1240
1238
  f'Unsupported return type. Got {type(result)}. should be any of {str(supported_result_types)}')
@@ -1400,6 +1398,43 @@ def tensorleap_preprocess():
1400
1398
  return decorating_function
1401
1399
 
1402
1400
 
1401
def tensorleap_simulation(name: str, sim_params: dict):
    """Decorator factory that registers the decorated function as a simulation.

    Decoration-time checks on the function's signature:
    - It must accept every key of ``sim_params`` plus ``N``.
    - The only other parameter it may declare is ``seed``.

    After the checks, the function is registered via
    ``leap_binder.set_simulation(user_function, name, sim_params)``.

    Args:
        name: Name under which the simulation is registered.
        sim_params: Raw simulation config, passed unchanged to
            ``leap_binder.set_simulation`` (presumably validated there by
            ``parse_sim_config``).

    Returns:
        A decorator. The wrapper it produces calls the user function and
        asserts that the result is a ``PreprocessResponse``.

    Raises:
        Exception: at decoration time, if the signature does not match.
            Errors raised by ``leap_binder.set_simulation`` also propagate.
    """
    def decorating_function(user_function: Callable):
        # Local import so this block does not rely on a module-level functools import.
        import functools

        sig = inspect.signature(user_function)
        func_params = set(sig.parameters.keys())
        expected_params = set(sim_params.keys()) | {"N"}

        missing = expected_params - func_params
        if missing:
            raise Exception(
                f"{user_function.__name__}() registration failed: "
                f"Missing required parameters: {missing}. "
                f"Function must accept all sim_params params plus 'N'."
            )

        # 'seed' is the only optional extra parameter the function may declare.
        extra = func_params - expected_params - {"seed"}
        if extra:
            raise Exception(
                f"{user_function.__name__}() registration failed: "
                f"Unexpected parameters: {extra}. "
                f"Function must only accept sim_params params plus 'N' and optionally 'seed'."
            )

        # NOTE(review): the unwrapped function is registered, so the
        # PreprocessResponse check in `inner` only applies to direct calls of
        # the decorated function — confirm this is intended.
        leap_binder.set_simulation(user_function, name, sim_params)

        # functools.wraps preserves __name__/__doc__ and sets __wrapped__, so
        # inspect.signature(inner) still reports the user's real parameters.
        @functools.wraps(user_function)
        def inner(*args, **kwargs):
            result = user_function(*args, **kwargs)
            assert isinstance(result, PreprocessResponse), (
                f"{user_function.__name__}() validation failed: "
                f"Expected return type PreprocessResponse. Got {type(result).__name__}."
            )
            return result

        return inner

    return decorating_function
1436
+
1437
+
1403
1438
  def tensorleap_element_instance_preprocess(
1404
1439
  instance_length_encoder: InstanceLengthCallableInterface, instance_mask_encoder: InstanceCallableInterface):
1405
1440
  def decorating_function(user_function: Callable[[], List[PreprocessResponse]]):
code_loader/leaploader.py CHANGED
@@ -22,13 +22,22 @@ from code_loader.contract.exceptions import DatasetScriptException
22
22
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
23
23
  DatasetPreprocess, DatasetSetup, DatasetInputInstance, DatasetOutputInstance, DatasetMetadataInstance, \
24
24
  VisualizerInstance, PredictionTypeInstance, ModelSetup, CustomLayerInstance, MetricInstance, CustomLossInstance, \
25
- EngineFileContract
25
+ EngineFileContract, SimulationSetupInstance
26
+ from code_loader.contract.sim_config import FloatBounds, IntBounds, CategoricalBounds
26
27
  from code_loader.inner_leap_binder import global_leap_binder
27
28
  from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
28
29
  from code_loader.leaploaderbase import LeapLoaderBase
29
30
  from code_loader.utils import get_root_exception_file_and_line_number, get_metadata_type_from_variable
30
31
 
31
32
 
33
def _serialize_sim_bounds(bounds) -> dict:
    """Convert a parsed bounds object into a plain dict for the dataset setup payload.

    Output shapes:
    - ``FloatBounds`` / ``IntBounds`` become ``{"min": ..., "max": ...}``.
    - ``CategoricalBounds`` becomes ``{"values": [...]}``.

    Raises:
        ValueError: for any other bounds type.
    """
    if isinstance(bounds, (FloatBounds, IntBounds)):
        return {"min": bounds.min, "max": bounds.max}
    if isinstance(bounds, CategoricalBounds):
        # Copy so the serialized payload does not alias (and cannot mutate)
        # the list held by the registered simulation config.
        return {"values": list(bounds.values)}
    raise ValueError(f"Unknown bounds type: {type(bounds)}")
39
+
40
+
32
41
  class LeapLoader(LeapLoaderBase):
33
42
  def __init__(self, code_path: str, code_entry_name: str):
34
43
  super().__init__(code_path, code_entry_name)
@@ -454,9 +463,20 @@ class LeapLoader(LeapLoaderBase):
454
463
  metric_inst = MetricInstance(metric.metric_handler_data.name, metric.metric_handler_data.arg_names)
455
464
  metrics.append(metric_inst)
456
465
 
466
+ simulations = []
467
+ for sim in setup.simulations:
468
+ sim_config_serialized = {
469
+ name: {
470
+ "type": param.type.value,
471
+ "bounds": _serialize_sim_bounds(param.bounds),
472
+ }
473
+ for name, param in sim.sim_config.items()
474
+ }
475
+ simulations.append(SimulationSetupInstance(name=sim.name, sim_config=sim_config_serialized))
476
+
457
477
  return DatasetSetup(preprocess=dataset_preprocess, inputs=inputs, outputs=ground_truths,
458
478
  metadata=metadata_instances, visualizers=visualizers, prediction_types=prediction_types,
459
- custom_losses=custom_losses, metrics=metrics)
479
+ custom_losses=custom_losses, metrics=metrics, simulations=simulations)
460
480
 
461
481
  def get_model_setup_response(self) -> ModelSetup:
462
482
  setup = global_leap_binder.setup_container
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.179.dev2
3
+ Version: 1.0.180.dev0
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -1,11 +1,12 @@
1
1
  LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=outxRQ0M-zMfV0QGVJmAed5qWfRmyD0TV6-goEGAzBw,406
3
3
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- code_loader/contract/datasetclasses.py,sha256=nQYwxfGpeb60AGYlHwJVV4qPbAE2vEnYO0d_dzvnp5Q,10039
4
+ code_loader/contract/datasetclasses.py,sha256=hMovDpEEl-07lPKV8fp52HYd0yK6trkta7oTZ8cKU6I,10277
5
5
  code_loader/contract/enums.py,sha256=2q-IV_5g9lLE306DIbWA1c0tn5IhDtxsKxyV1x_Lreg,1671
6
6
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
7
7
  code_loader/contract/mapping.py,sha256=sWJhpng-IkOzQnWQdMT5w2ZZ3X1Z_OOzSwCLXIS7oxE,1446
8
- code_loader/contract/responsedataclasses.py,sha256=FE04qgm4gz30RWbB_6UKeyN6BHyg9F1Xjn5VRNTUSp8,4769
8
+ code_loader/contract/responsedataclasses.py,sha256=G37OBHmb-b9-0gc5Hs69t8jjjRtEzmWP389Y9IWWlGQ,4935
9
+ code_loader/contract/sim_config.py,sha256=tWgM3n2_UTP8HnWBUFYLY928ya1ezFWhMOOU9c3WsHU,3002
9
10
  code_loader/contract/visualizer_classes.py,sha256=Wz9eItmoRaKEHa3p0aW0Ypxx4_xUmaZyLBznnTuxwi0,15425
10
11
  code_loader/default_losses.py,sha256=NoOQym1106bDN5dcIk56Elr7ZG5quUHArqfP5-Nyxyo,1139
11
12
  code_loader/default_metrics.py,sha256=2XSlyNw_XLDGSJDoz5W_Evi5wbL0dhwq24pPr15vSPc,5025
@@ -20,9 +21,9 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
20
21
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
21
22
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
22
23
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
23
- code_loader/inner_leap_binder/leapbinder.py,sha256=Xrwx0Kptcz2dmoHoZQ7kyyNmitsvAKoTBngIMBvh--I,37859
24
- code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=M4DQKAI3PnxzxdjYkssOtbfqMvIT08CJS7LpKbzl4Uw,100896
25
- code_loader/leaploader.py,sha256=VSSxsU2KxHPo5SzN1bJtRXaz29IE0XNFxgrnAJyG3io,32162
24
+ code_loader/inner_leap_binder/leapbinder.py,sha256=07_ZIjRLKuERF12fo3IohMPNjfDUeZQxjk1vJMGtd1A,38484
25
+ code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=YM-gyBi2wHaoyUzhpPtd5-96wAdeKohCrEiSB2honMQ,102282
26
+ code_loader/leaploader.py,sha256=-AnZ_-_FexkW8VRyC2njK-aoE_x0Ud2RRxeZL4OHt5A,33032
26
27
  code_loader/leaploaderbase.py,sha256=NJGaas8S6JeHYSsKkMSutyfcSKdK9jXTic7BcjC5uNc,6303
27
28
  code_loader/mixpanel_tracker.py,sha256=U7eUGrPjc-2rgFG7isqosf65tKZkotQ0XKuAML_lIjA,9067
28
29
  code_loader/plot_functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -31,7 +32,7 @@ code_loader/plot_functions/visualize.py,sha256=gsBAYYkwMh7jIpJeDMPS8G4CW-pxwx6Lz
31
32
  code_loader/utils.py,sha256=YecipkdTA-VcE9F0RQcY9cFnY8P3AksPnHM2Db7xUSk,3972
32
33
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
34
  code_loader/visualizers/default_visualizers.py,sha256=onRnLE_TXfgLN4o52hQIOOhUcFexGlqJ3xSpQDVLuZM,2604
34
- code_loader-1.0.179.dev2.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
35
- code_loader-1.0.179.dev2.dist-info/METADATA,sha256=SwqY8uUIwvKGedEGMdZZMUkMrZiVnrTn09nPGHy4D0c,1095
36
- code_loader-1.0.179.dev2.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
37
- code_loader-1.0.179.dev2.dist-info/RECORD,,
35
+ code_loader-1.0.180.dev0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
36
+ code_loader-1.0.180.dev0.dist-info/METADATA,sha256=RM5e7ncL3g_HCtahcJcAVdJPoztrnhuS98Jszze_Jpw,1095
37
+ code_loader-1.0.180.dev0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
38
+ code_loader-1.0.180.dev0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.9.1
2
+ Generator: poetry-core 1.9.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any