humalab 0.0.4__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of humalab might be problematic. Click here for more details.
- humalab-0.0.6/.github/pull_request_template.md +8 -0
- {humalab-0.0.4/humalab.egg-info → humalab-0.0.6}/PKG-INFO +1 -1
- humalab-0.0.6/VERSION +1 -0
- humalab-0.0.6/humalab/__init__.py +20 -0
- humalab-0.0.6/humalab/assets/__init__.py +4 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/assets/files/resource_file.py +29 -3
- {humalab-0.0.4 → humalab-0.0.6}/humalab/assets/files/urdf_file.py +14 -10
- humalab-0.0.6/humalab/assets/resource_operator.py +91 -0
- humalab-0.0.6/humalab/constants.py +41 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/dists/bernoulli.py +16 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/dists/categorical.py +4 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/dists/discrete.py +22 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/dists/gaussian.py +22 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/dists/log_uniform.py +22 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/dists/truncated_gaussian.py +36 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/dists/uniform.py +22 -0
- humalab-0.0.6/humalab/episode.py +196 -0
- humalab-0.0.6/humalab/humalab.py +180 -0
- humalab-0.0.6/humalab/humalab_api_client.py +971 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/humalab_config.py +0 -13
- {humalab-0.0.4 → humalab-0.0.6}/humalab/humalab_test.py +46 -29
- humalab-0.0.6/humalab/metrics/__init__.py +11 -0
- humalab-0.0.6/humalab/metrics/code.py +28 -0
- humalab-0.0.6/humalab/metrics/metric.py +62 -0
- humalab-0.0.6/humalab/metrics/scenario_stats.py +95 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/metrics/summary.py +24 -18
- humalab-0.0.6/humalab/run.py +279 -0
- humalab-0.0.6/humalab/scenarios/__init__.py +4 -0
- humalab-0.0.6/humalab/scenarios/scenario.py +372 -0
- humalab-0.0.6/humalab/scenarios/scenario_operator.py +82 -0
- {humalab-0.0.4/humalab → humalab-0.0.6/humalab/scenarios}/scenario_test.py +150 -269
- humalab-0.0.6/humalab/utils.py +37 -0
- {humalab-0.0.4 → humalab-0.0.6/humalab.egg-info}/PKG-INFO +1 -1
- {humalab-0.0.4 → humalab-0.0.6}/humalab.egg-info/SOURCES.txt +11 -5
- {humalab-0.0.4 → humalab-0.0.6}/pyproject.toml +1 -1
- humalab-0.0.4/VERSION +0 -1
- humalab-0.0.4/humalab/__init__.py +0 -9
- humalab-0.0.4/humalab/assets/__init__.py +0 -4
- humalab-0.0.4/humalab/assets/resource_manager.py +0 -57
- humalab-0.0.4/humalab/constants.py +0 -7
- humalab-0.0.4/humalab/humalab.py +0 -217
- humalab-0.0.4/humalab/humalab_api_client.py +0 -273
- humalab-0.0.4/humalab/metrics/__init__.py +0 -11
- humalab-0.0.4/humalab/metrics/dist_metric.py +0 -22
- humalab-0.0.4/humalab/metrics/metric.py +0 -129
- humalab-0.0.4/humalab/run.py +0 -214
- humalab-0.0.4/humalab/scenario.py +0 -225
- {humalab-0.0.4 → humalab-0.0.6}/.gitignore +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/LICENSE +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/Makefile +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/README.md +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/build.sh +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/assets/archive.py +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/assets/files/__init__.py +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/dists/__init__.py +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab/dists/distribution.py +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab.egg-info/dependency_links.txt +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab.egg-info/entry_points.txt +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab.egg-info/not-zip-safe +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab.egg-info/requires.txt +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/humalab.egg-info/top_level.txt +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/requirements-dev.txt +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/requirements.txt +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/setup.cfg +0 -0
- {humalab-0.0.4 → humalab-0.0.6}/setup.py +0 -0
humalab-0.0.6/VERSION
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
0.0.6
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from humalab.humalab import init, finish, login
|
|
2
|
+
from humalab import assets
|
|
3
|
+
from humalab import metrics
|
|
4
|
+
from humalab import scenarios
|
|
5
|
+
from humalab.run import Run
|
|
6
|
+
from humalab.constants import MetricDimType, GraphType
|
|
7
|
+
# from humalab import evaluators
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"init",
|
|
11
|
+
"finish",
|
|
12
|
+
"login",
|
|
13
|
+
"assets",
|
|
14
|
+
"metrics",
|
|
15
|
+
"scenarios",
|
|
16
|
+
"Run",
|
|
17
|
+
"MetricDimType",
|
|
18
|
+
"GraphType",
|
|
19
|
+
# "evaluators",
|
|
20
|
+
]
|
|
@@ -1,4 +1,17 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
+
from enum import Enum
|
|
3
|
+
|
|
4
|
+
from humalab.constants import DEFAULT_PROJECT
|
|
5
|
+
|
|
6
|
+
class ResourceType(Enum):
|
|
7
|
+
URDF = "urdf"
|
|
8
|
+
MJCF = "mjcf"
|
|
9
|
+
USD = "usd"
|
|
10
|
+
MESH = "mesh"
|
|
11
|
+
VIDEO = "video"
|
|
12
|
+
IMAGE = "image"
|
|
13
|
+
DATA = "data"
|
|
14
|
+
|
|
2
15
|
|
|
3
16
|
|
|
4
17
|
class ResourceFile:
|
|
@@ -6,16 +19,22 @@ class ResourceFile:
|
|
|
6
19
|
name: str,
|
|
7
20
|
version: int,
|
|
8
21
|
filename: str,
|
|
9
|
-
resource_type: str,
|
|
22
|
+
resource_type: str | ResourceType,
|
|
23
|
+
project: str = DEFAULT_PROJECT,
|
|
10
24
|
description: str | None = None,
|
|
11
25
|
created_at: datetime | None = None):
|
|
26
|
+
self._project = project
|
|
12
27
|
self._name = name
|
|
13
28
|
self._version = version
|
|
14
29
|
self._filename = filename
|
|
15
|
-
self._resource_type = resource_type
|
|
30
|
+
self._resource_type = ResourceType(resource_type)
|
|
16
31
|
self._description = description
|
|
17
32
|
self._created_at = created_at
|
|
18
33
|
|
|
34
|
+
@property
|
|
35
|
+
def project(self) -> str:
|
|
36
|
+
return self._project
|
|
37
|
+
|
|
19
38
|
@property
|
|
20
39
|
def name(self) -> str:
|
|
21
40
|
return self._name
|
|
@@ -29,7 +48,7 @@ class ResourceFile:
|
|
|
29
48
|
return self._filename
|
|
30
49
|
|
|
31
50
|
@property
|
|
32
|
-
def resource_type(self) ->
|
|
51
|
+
def resource_type(self) -> ResourceType:
|
|
33
52
|
return self._resource_type
|
|
34
53
|
|
|
35
54
|
@property
|
|
@@ -39,3 +58,10 @@ class ResourceFile:
|
|
|
39
58
|
@property
|
|
40
59
|
def description(self) -> str | None:
|
|
41
60
|
return self._description
|
|
61
|
+
|
|
62
|
+
def __repr__(self) -> str:
|
|
63
|
+
return f"ResourceFile(project={self._project}, name={self._name}, version={self._version}, filename={self._filename}, resource_type={self._resource_type}, description={self._description}, created_at={self._created_at})"
|
|
64
|
+
|
|
65
|
+
def __str__(self) -> str:
|
|
66
|
+
return self.__repr__()
|
|
67
|
+
|
|
@@ -1,8 +1,10 @@
|
|
|
1
|
-
from datetime import datetime
|
|
2
1
|
import os
|
|
3
2
|
import glob
|
|
4
|
-
from
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
|
|
5
|
+
from humalab.assets.files.resource_file import ResourceFile, ResourceType
|
|
5
6
|
from humalab.assets.archive import extract_archive
|
|
7
|
+
from humalab.constants import DEFAULT_PROJECT
|
|
6
8
|
|
|
7
9
|
|
|
8
10
|
class URDFFile(ResourceFile):
|
|
@@ -10,30 +12,32 @@ class URDFFile(ResourceFile):
|
|
|
10
12
|
name: str,
|
|
11
13
|
version: int,
|
|
12
14
|
filename: str,
|
|
15
|
+
project: str = DEFAULT_PROJECT,
|
|
13
16
|
urdf_filename: str | None = None,
|
|
14
17
|
description: str | None = None,
|
|
15
18
|
created_at: datetime | None = None,):
|
|
16
|
-
super().__init__(
|
|
19
|
+
super().__init__(project=project,
|
|
20
|
+
name=name,
|
|
17
21
|
version=version,
|
|
18
22
|
description=description,
|
|
19
23
|
filename=filename,
|
|
20
|
-
resource_type=
|
|
24
|
+
resource_type=ResourceType.URDF,
|
|
21
25
|
created_at=created_at)
|
|
22
26
|
self._urdf_base_filename = urdf_filename
|
|
23
27
|
self._urdf_filename, self._root_path = self._extract()
|
|
24
28
|
self._urdf_filename = os.path.join(self._urdf_filename, self._urdf_filename)
|
|
25
29
|
|
|
26
30
|
def _extract(self):
|
|
27
|
-
working_path = os.path.dirname(self.
|
|
28
|
-
if
|
|
29
|
-
_, ext = os.path.splitext(self.
|
|
31
|
+
working_path = os.path.dirname(self.filename)
|
|
32
|
+
if os.path.exists(self.filename):
|
|
33
|
+
_, ext = os.path.splitext(self.filename)
|
|
30
34
|
ext = ext.lstrip('.') # Remove leading dot
|
|
31
35
|
if ext.lower() != "urdf":
|
|
32
|
-
extract_archive(self.
|
|
36
|
+
extract_archive(self.filename, working_path)
|
|
33
37
|
try:
|
|
34
|
-
os.remove(self.
|
|
38
|
+
os.remove(self.filename)
|
|
35
39
|
except Exception as e:
|
|
36
|
-
print(f"Error removing saved file {self.
|
|
40
|
+
print(f"Error removing saved file {self.filename}: {e}")
|
|
37
41
|
local_filename = self.search_resource_file(self._urdf_base_filename, working_path)
|
|
38
42
|
if local_filename is None:
|
|
39
43
|
raise ValueError(f"Resource filename {self._urdf_base_filename} not found in {working_path}")
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from humalab.constants import DEFAULT_PROJECT
|
|
2
|
+
from humalab.assets.files.resource_file import ResourceFile, ResourceType
|
|
3
|
+
from humalab.humalab_config import HumalabConfig
|
|
4
|
+
from humalab.humalab_api_client import HumaLabApiClient
|
|
5
|
+
from humalab.assets.files.urdf_file import URDFFile
|
|
6
|
+
import os
|
|
7
|
+
from typing import Any, Optional
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _asset_dir(humalab_config: HumalabConfig, name: str, version: int) -> str:
|
|
11
|
+
return os.path.join(humalab_config.workspace_path, "assets", name, f"{version}")
|
|
12
|
+
|
|
13
|
+
def _create_asset_dir(humalab_config: HumalabConfig, name: str, version: int) -> bool:
|
|
14
|
+
asset_dir = _asset_dir(humalab_config, name, version)
|
|
15
|
+
if not os.path.exists(asset_dir):
|
|
16
|
+
os.makedirs(asset_dir, exist_ok=True)
|
|
17
|
+
return True
|
|
18
|
+
return False
|
|
19
|
+
|
|
20
|
+
def download(name: str,
|
|
21
|
+
version: int | None=None,
|
|
22
|
+
project: str = DEFAULT_PROJECT,
|
|
23
|
+
|
|
24
|
+
host: str | None = None,
|
|
25
|
+
api_key: str | None = None,
|
|
26
|
+
timeout: float | None = None,
|
|
27
|
+
) -> Any:
|
|
28
|
+
humalab_config = HumalabConfig()
|
|
29
|
+
|
|
30
|
+
api_client = HumaLabApiClient(base_url=host,
|
|
31
|
+
api_key=api_key,
|
|
32
|
+
timeout=timeout)
|
|
33
|
+
|
|
34
|
+
resource = api_client.get_resource(project_name=project, name=name, version=version)
|
|
35
|
+
filename = os.path.basename(resource['resource_url'])
|
|
36
|
+
filename = os.path.join(_asset_dir(humalab_config, name, resource["version"]), filename)
|
|
37
|
+
if _create_asset_dir(humalab_config, name, resource["version"]):
|
|
38
|
+
file_content = api_client.download_resource(project_name=project, name="lerobot")
|
|
39
|
+
with open(filename, "wb") as f:
|
|
40
|
+
f.write(file_content)
|
|
41
|
+
|
|
42
|
+
if resource["resource_type"].lower() == "urdf":
|
|
43
|
+
return URDFFile(project=project,
|
|
44
|
+
name=name,
|
|
45
|
+
version=resource["version"],
|
|
46
|
+
description=resource.get("description"),
|
|
47
|
+
filename=filename,
|
|
48
|
+
urdf_filename=resource.get("filename"),
|
|
49
|
+
created_at=resource.get("created_at"))
|
|
50
|
+
|
|
51
|
+
return ResourceFile(project=project,
|
|
52
|
+
name=name,
|
|
53
|
+
version=resource["version"],
|
|
54
|
+
filename=filename,
|
|
55
|
+
resource_type=resource["resource_type"],
|
|
56
|
+
description=resource.get("description"),
|
|
57
|
+
created_at=resource.get("created_at"))
|
|
58
|
+
|
|
59
|
+
def list_resources(project: str = DEFAULT_PROJECT,
|
|
60
|
+
resource_types: Optional[list[str | ResourceType]] = None,
|
|
61
|
+
limit: int = 20,
|
|
62
|
+
offset: int = 0,
|
|
63
|
+
latest_only: bool = True,
|
|
64
|
+
|
|
65
|
+
host: str | None = None,
|
|
66
|
+
api_key: str | None = None,
|
|
67
|
+
timeout: float | None = None,) -> list[ResourceFile]:
|
|
68
|
+
api_client = HumaLabApiClient(base_url=host,
|
|
69
|
+
api_key=api_key,
|
|
70
|
+
timeout=timeout)
|
|
71
|
+
|
|
72
|
+
resource_type_string = None
|
|
73
|
+
if resource_types:
|
|
74
|
+
resource_type_strings = {rt.value if isinstance(rt, ResourceType) else rt for rt in resource_types}
|
|
75
|
+
resource_type_string = ",".join(resource_type_strings)
|
|
76
|
+
resp = api_client.get_resources(project_name=project,
|
|
77
|
+
resource_types=resource_type_string,
|
|
78
|
+
limit=limit,
|
|
79
|
+
offset=offset,
|
|
80
|
+
latest_only=latest_only)
|
|
81
|
+
resources = resp.get("resources", [])
|
|
82
|
+
ret_list = []
|
|
83
|
+
for resource in resources:
|
|
84
|
+
ret_list.append(ResourceFile(name=resource["name"],
|
|
85
|
+
version=resource.get("version"),
|
|
86
|
+
project=project,
|
|
87
|
+
filename=resource.get("filename"),
|
|
88
|
+
resource_type=resource.get("resource_type"),
|
|
89
|
+
description=resource.get("description"),
|
|
90
|
+
created_at=resource.get("created_at")))
|
|
91
|
+
return ret_list
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
RESERVED_NAMES = {
|
|
5
|
+
"sceanario"
|
|
6
|
+
}
|
|
7
|
+
|
|
8
|
+
DEFAULT_PROJECT = "default"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ArtifactType(Enum):
|
|
12
|
+
"""Types of artifacts that can be stored"""
|
|
13
|
+
METRICS = "metrics" # Run & Episode
|
|
14
|
+
SCENARIO_STATS = "scenario_stats" # Run only
|
|
15
|
+
PYTHON = "python" # Run & Episode
|
|
16
|
+
CODE = "code" # Run & Episode (YAML)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MetricType(Enum):
|
|
20
|
+
METRICS = ArtifactType.METRICS.value
|
|
21
|
+
SCENARIO_STATS = ArtifactType.SCENARIO_STATS.value
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class GraphType(Enum):
|
|
25
|
+
"""Types of graphs supported by Humalab."""
|
|
26
|
+
NUMERIC = "numeric"
|
|
27
|
+
LINE = "line"
|
|
28
|
+
BAR = "bar"
|
|
29
|
+
SCATTER = "scatter"
|
|
30
|
+
HISTOGRAM = "histogram"
|
|
31
|
+
GAUSSIAN = "gaussian"
|
|
32
|
+
HEATMAP = "heatmap"
|
|
33
|
+
THREE_D_MAP = "3d_map"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MetricDimType(Enum):
|
|
37
|
+
"""Types of metric dimensions"""
|
|
38
|
+
ZERO_D = "0d"
|
|
39
|
+
ONE_D = "1d"
|
|
40
|
+
TWO_D = "2d"
|
|
41
|
+
THREE_D = "3d"
|
|
@@ -20,6 +20,22 @@ class Bernoulli(Distribution):
|
|
|
20
20
|
self._p = p
|
|
21
21
|
self._size = size
|
|
22
22
|
|
|
23
|
+
@staticmethod
|
|
24
|
+
def validate(dimensions: int, *args) -> bool:
|
|
25
|
+
arg1 = args[0]
|
|
26
|
+
if dimensions == 0:
|
|
27
|
+
if not isinstance(arg1, (int, float)):
|
|
28
|
+
return False
|
|
29
|
+
return True
|
|
30
|
+
if dimensions == -1:
|
|
31
|
+
return True
|
|
32
|
+
if not isinstance(arg1, (int, float)):
|
|
33
|
+
if isinstance(arg1, (list, np.ndarray)):
|
|
34
|
+
if len(arg1) != dimensions:
|
|
35
|
+
return False
|
|
36
|
+
|
|
37
|
+
return True
|
|
38
|
+
|
|
23
39
|
def _sample(self) -> int | float | np.ndarray:
|
|
24
40
|
return self._generator.binomial(n=1, p=self._p, size=self._size)
|
|
25
41
|
|
|
@@ -25,6 +25,10 @@ class Categorical(Distribution):
|
|
|
25
25
|
weights = [w / weight_sum for w in weights]
|
|
26
26
|
self._weights = weights
|
|
27
27
|
|
|
28
|
+
@staticmethod
|
|
29
|
+
def validate(dimensions: int, *args) -> bool:
|
|
30
|
+
return True
|
|
31
|
+
|
|
28
32
|
def _sample(self) -> int | float | np.ndarray:
|
|
29
33
|
return self._generator.choice(self._choices, size=self._size, p=self._weights)
|
|
30
34
|
|
|
@@ -26,6 +26,28 @@ class Discrete(Distribution):
|
|
|
26
26
|
self._high = np.array(high)
|
|
27
27
|
self._size = size
|
|
28
28
|
self._endpoint = endpoint if endpoint is not None else True
|
|
29
|
+
|
|
30
|
+
@staticmethod
|
|
31
|
+
def validate(dimensions: int, *args) -> bool:
|
|
32
|
+
arg1 = args[0]
|
|
33
|
+
arg2 = args[1]
|
|
34
|
+
if dimensions == 0:
|
|
35
|
+
if not isinstance(arg1, int):
|
|
36
|
+
return False
|
|
37
|
+
if not isinstance(arg2, int):
|
|
38
|
+
return False
|
|
39
|
+
return True
|
|
40
|
+
if dimensions == -1:
|
|
41
|
+
return True
|
|
42
|
+
if not isinstance(arg1, int):
|
|
43
|
+
if isinstance(arg1, (list, np.ndarray)):
|
|
44
|
+
if len(arg1) != dimensions:
|
|
45
|
+
return False
|
|
46
|
+
if not isinstance(arg2, int):
|
|
47
|
+
if isinstance(arg2, (list, np.ndarray)):
|
|
48
|
+
if len(arg2) != dimensions:
|
|
49
|
+
return False
|
|
50
|
+
return True
|
|
29
51
|
|
|
30
52
|
def _sample(self) -> int | float | np.ndarray:
|
|
31
53
|
return self._generator.integers(self._low, self._high, size=self._size, endpoint=self._endpoint)
|
|
@@ -23,6 +23,28 @@ class Gaussian(Distribution):
|
|
|
23
23
|
self._scale = scale
|
|
24
24
|
self._size = size
|
|
25
25
|
|
|
26
|
+
@staticmethod
|
|
27
|
+
def validate(dimensions: int, *args) -> bool:
|
|
28
|
+
arg1 = args[0]
|
|
29
|
+
arg2 = args[1]
|
|
30
|
+
if dimensions == 0:
|
|
31
|
+
if not isinstance(arg1, (int, float)):
|
|
32
|
+
return False
|
|
33
|
+
if not isinstance(arg2, (int, float)):
|
|
34
|
+
return False
|
|
35
|
+
return True
|
|
36
|
+
if dimensions == -1:
|
|
37
|
+
return True
|
|
38
|
+
if not isinstance(arg1, (int, float)):
|
|
39
|
+
if isinstance(arg1, (list, np.ndarray)):
|
|
40
|
+
if len(arg1) != dimensions:
|
|
41
|
+
return False
|
|
42
|
+
if not isinstance(arg2, (int, float)):
|
|
43
|
+
if isinstance(arg2, (list, np.ndarray)):
|
|
44
|
+
if len(arg2) != dimensions:
|
|
45
|
+
return False
|
|
46
|
+
return True
|
|
47
|
+
|
|
26
48
|
def _sample(self) -> int | float | np.ndarray:
|
|
27
49
|
return self._generator.normal(loc=self._loc, scale=self._scale, size=self._size)
|
|
28
50
|
|
|
@@ -22,6 +22,28 @@ class LogUniform(Distribution):
|
|
|
22
22
|
self._log_low = np.log(np.array(low))
|
|
23
23
|
self._log_high = np.log(np.array(high))
|
|
24
24
|
self._size = size
|
|
25
|
+
|
|
26
|
+
@staticmethod
|
|
27
|
+
def validate(dimensions: int, *args) -> bool:
|
|
28
|
+
arg1 = args[0]
|
|
29
|
+
arg2 = args[1]
|
|
30
|
+
if dimensions == 0:
|
|
31
|
+
if not isinstance(arg1, (int, float)):
|
|
32
|
+
return False
|
|
33
|
+
if not isinstance(arg2, (int, float)):
|
|
34
|
+
return False
|
|
35
|
+
return True
|
|
36
|
+
if dimensions == -1:
|
|
37
|
+
return True
|
|
38
|
+
if not isinstance(arg1, (int, float)):
|
|
39
|
+
if isinstance(arg1, (list, np.ndarray)):
|
|
40
|
+
if len(arg1) != dimensions:
|
|
41
|
+
return False
|
|
42
|
+
if not isinstance(arg2, (int, float)):
|
|
43
|
+
if isinstance(arg2, (list, np.ndarray)):
|
|
44
|
+
if len(arg2) != dimensions:
|
|
45
|
+
return False
|
|
46
|
+
return True
|
|
25
47
|
|
|
26
48
|
def _sample(self) -> int | float | np.ndarray:
|
|
27
49
|
return np.exp(self._generator.uniform(self._log_low, self._log_high, size=self._size))
|
|
@@ -29,6 +29,42 @@ class TruncatedGaussian(Distribution):
|
|
|
29
29
|
self._high = high
|
|
30
30
|
self._size = size
|
|
31
31
|
|
|
32
|
+
@staticmethod
|
|
33
|
+
def validate(dimensions: int, *args) -> bool:
|
|
34
|
+
arg1 = args[0]
|
|
35
|
+
arg2 = args[1]
|
|
36
|
+
arg3 = args[2]
|
|
37
|
+
arg4 = args[3]
|
|
38
|
+
if dimensions == 0:
|
|
39
|
+
if not isinstance(arg1, (int, float)):
|
|
40
|
+
return False
|
|
41
|
+
if not isinstance(arg2, (int, float)):
|
|
42
|
+
return False
|
|
43
|
+
if not isinstance(arg3, (int, float)):
|
|
44
|
+
return False
|
|
45
|
+
if not isinstance(arg4, (int, float)):
|
|
46
|
+
return False
|
|
47
|
+
return True
|
|
48
|
+
if dimensions == -1:
|
|
49
|
+
return True
|
|
50
|
+
if not isinstance(arg1, (int, float)):
|
|
51
|
+
if isinstance(arg1, (list, np.ndarray)):
|
|
52
|
+
if len(arg1) != dimensions:
|
|
53
|
+
return False
|
|
54
|
+
if not isinstance(arg2, (int, float)):
|
|
55
|
+
if isinstance(arg2, (list, np.ndarray)):
|
|
56
|
+
if len(arg2) != dimensions:
|
|
57
|
+
return False
|
|
58
|
+
if not isinstance(arg3, (int, float)):
|
|
59
|
+
if isinstance(arg3, (list, np.ndarray)):
|
|
60
|
+
if len(arg3) != dimensions:
|
|
61
|
+
return False
|
|
62
|
+
if not isinstance(arg4, (int, float)):
|
|
63
|
+
if isinstance(arg4, (list, np.ndarray)):
|
|
64
|
+
if len(arg4) != dimensions:
|
|
65
|
+
return False
|
|
66
|
+
return True
|
|
67
|
+
|
|
32
68
|
def _sample(self) -> int | float | np.ndarray:
|
|
33
69
|
samples = self._generator.normal(loc=self._loc, scale=self._scale, size=self._size)
|
|
34
70
|
mask = (samples < self._low) | (samples > self._high)
|
|
@@ -23,6 +23,28 @@ class Uniform(Distribution):
|
|
|
23
23
|
self._high = np.array(high)
|
|
24
24
|
self._size = size
|
|
25
25
|
|
|
26
|
+
@staticmethod
|
|
27
|
+
def validate(dimensions: int, *args) -> bool:
|
|
28
|
+
arg1 = args[0]
|
|
29
|
+
arg2 = args[1]
|
|
30
|
+
if dimensions == 0:
|
|
31
|
+
if not isinstance(arg1, (int, float)):
|
|
32
|
+
return False
|
|
33
|
+
if not isinstance(arg2, (int, float)):
|
|
34
|
+
return False
|
|
35
|
+
return True
|
|
36
|
+
if dimensions == -1:
|
|
37
|
+
return True
|
|
38
|
+
if not isinstance(arg1, (int, float)):
|
|
39
|
+
if isinstance(arg1, (list, np.ndarray)):
|
|
40
|
+
if len(arg1) > dimensions:
|
|
41
|
+
return False
|
|
42
|
+
if not isinstance(arg2, (int, float)):
|
|
43
|
+
if isinstance(arg2, (list, np.ndarray)):
|
|
44
|
+
if len(arg2) > dimensions:
|
|
45
|
+
return False
|
|
46
|
+
return True
|
|
47
|
+
|
|
26
48
|
def _sample(self) -> int | float | np.ndarray:
|
|
27
49
|
return self._generator.uniform(self._low, self._high, size=self._size)
|
|
28
50
|
|