libinephany 1.1.2__tar.gz → 1.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. libinephany-1.1.4/CODE_VERSION.cfg +1 -0
  2. {libinephany-1.1.2/libinephany.egg-info → libinephany-1.1.4}/PKG-INFO +1 -1
  3. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observation_utils.py +64 -7
  4. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/base_observers.py +47 -0
  5. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/statistic_trackers.py +13 -3
  6. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/schemas/inner_task_profile.py +39 -55
  7. {libinephany-1.1.2 → libinephany-1.1.4/libinephany.egg-info}/PKG-INFO +1 -1
  8. libinephany-1.1.2/CODE_VERSION.cfg +0 -1
  9. {libinephany-1.1.2 → libinephany-1.1.4}/LICENSE +0 -0
  10. {libinephany-1.1.2 → libinephany-1.1.4}/MANIFEST.in +0 -0
  11. {libinephany-1.1.2 → libinephany-1.1.4}/README.md +0 -0
  12. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/__init__.py +0 -0
  13. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/aws/__init__.py +0 -0
  14. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/aws/s3_functions.py +0 -0
  15. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/__init__.py +0 -0
  16. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observer_pipeline.py +0 -0
  17. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/__init__.py +0 -0
  18. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/global_observers/__init__.py +0 -0
  19. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/global_observers/base_classes.py +0 -0
  20. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/global_observers/constants.py +0 -0
  21. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/global_observers/gradient_observers.py +0 -0
  22. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/global_observers/hyperparameter_observers.py +0 -0
  23. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/global_observers/loss_observers.py +0 -0
  24. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/global_observers/model_observers.py +0 -0
  25. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/global_observers/progress_observers.py +0 -0
  26. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/local_observers.py +0 -0
  27. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/observers/observer_containers.py +0 -0
  28. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/pipeline_coordinator.py +0 -0
  29. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/post_processors/__init__.py +0 -0
  30. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/post_processors/postprocessors.py +0 -0
  31. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/observations/statistic_manager.py +0 -0
  32. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/__init__.py +0 -0
  33. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/configs/__init__.py +0 -0
  34. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/configs/hyperparameter_configs.py +0 -0
  35. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/configs/observer_config.py +0 -0
  36. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/configs/outer_model_config.py +0 -0
  37. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/schemas/__init__.py +0 -0
  38. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/schemas/agent_info.py +0 -0
  39. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/schemas/observation_models.py +0 -0
  40. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/schemas/request_schemas.py +0 -0
  41. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/schemas/response_schemas.py +0 -0
  42. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/schemas/tensor_statistics.py +0 -0
  43. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/states/__init__.py +0 -0
  44. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/pydantic_models/states/hyperparameter_states.py +0 -0
  45. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/__init__.py +0 -0
  46. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/agent_utils.py +0 -0
  47. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/asyncio_worker.py +0 -0
  48. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/backend_statuses.py +0 -0
  49. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/constants.py +0 -0
  50. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/directory_utils.py +0 -0
  51. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/dropout_utils.py +0 -0
  52. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/enums.py +0 -0
  53. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/error_severities.py +0 -0
  54. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/exceptions.py +0 -0
  55. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/import_utils.py +0 -0
  56. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/optim_utils.py +0 -0
  57. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/random_seeds.py +0 -0
  58. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/samplers.py +0 -0
  59. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/standardizers.py +0 -0
  60. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/torch_distributed_utils.py +0 -0
  61. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/torch_utils.py +0 -0
  62. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/transforms.py +0 -0
  63. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/utils/typing.py +0 -0
  64. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/web_apps/__init__.py +0 -0
  65. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/web_apps/error_logger.py +0 -0
  66. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany/web_apps/web_app_utils.py +0 -0
  67. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany.egg-info/SOURCES.txt +0 -0
  68. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany.egg-info/dependency_links.txt +0 -0
  69. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany.egg-info/requires.txt +0 -0
  70. {libinephany-1.1.2 → libinephany-1.1.4}/libinephany.egg-info/top_level.txt +0 -0
  71. {libinephany-1.1.2 → libinephany-1.1.4}/pyproject.toml +0 -0
  72. {libinephany-1.1.2 → libinephany-1.1.4}/setup.cfg +0 -0
@@ -0,0 +1 @@
1
+ 1.1.4
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 1.1.2
3
+ Version: 1.1.4
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -13,6 +13,7 @@ import numpy as np
13
13
  import pandas as pd
14
14
  import torch
15
15
  import torch.optim as optim
16
+ from loguru import logger
16
17
  from scipy.stats import norm
17
18
 
18
19
  from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
@@ -60,18 +61,54 @@ class StatisticStorageTypes(Enum):
60
61
  # ======================================================================================================================
61
62
 
62
63
 
63
- def get_exponential_weighted_average(values: list[int | float]) -> float:
64
+ def get_exponential_weighted_average(
65
+ values: list[int | float], invalid_value_threshold: float = 1e10, tracker_name: str = "Unknown"
66
+ ) -> float:
64
67
  """
65
68
  :param values: List of values to average via EWA.
69
+ :param invalid_value_threshold: Threshold for invalid observations, default is 1e10.
70
+ :param tracker_name: Name of the tracker for error reporting, default is "Unknown".
66
71
  :return: EWA of the given values.
67
72
  """
73
+
74
+ # Logging statistic tracker name when input is invalid.
75
+
76
+ if len(values) == 0:
77
+ raise ValueError(
78
+ f"Statistic Tracker: {tracker_name} gathered data with empty list! It likely means a bug is triggered in the"
79
+ f" code related to {tracker_name} or the inner task. Please check up the code!"
80
+ )
81
+
82
+ if any(not isinstance(value, (float, int)) for value in values):
83
+ # Should never happen, but just in case
84
+ raise ValueError(
85
+ f"Statistic Tracker: {tracker_name} gathered data with invalid values (not int or float)! It likely means a"
86
+ f" bug is triggered in the code related to {tracker_name} or the inner task. Please check up the code!"
87
+ )
88
+
89
+ if any(abs(value) > invalid_value_threshold for value in values):
90
+ # check for large values (including negative values)
91
+ logger.warning(
92
+ f"Statistic Tracker: {tracker_name} gathered data with values out of range (invalid_value_threshold: "
93
+ f"{invalid_value_threshold})! May cause episode termination if StopEpisodeFromInvalidObservations is used."
94
+ )
95
+
96
+ if any(np.isnan(value) for value in values):
97
+ logger.warning(
98
+ f"Statistic Tracker: {tracker_name} gathered data with NaN values! May cause episode termination if "
99
+ f"StopEpisodeFromInvalidObservations is used."
100
+ )
101
+
68
102
  exp_weighted_average = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
69
103
  assert isinstance(exp_weighted_average, float)
70
104
  return exp_weighted_average
71
105
 
72
106
 
73
107
  def apply_averaging_function_to_tensor_statistics(
74
- tensor_statistics: list[TensorStatistics], averaging_function: Callable[[list[float]], float]
108
+ tensor_statistics: list[TensorStatistics],
109
+ averaging_function: Callable[[list[float], float, str], float],
110
+ invalid_value_threshold: float = 1e10,
111
+ tracker_name: str = "Unknown",
75
112
  ) -> TensorStatistics:
76
113
  """
77
114
  :param tensor_statistics: List of statistics models to average over.
@@ -81,40 +118,60 @@ def apply_averaging_function_to_tensor_statistics(
81
118
 
82
119
  fields = TensorStatistics.model_fields.keys()
83
120
  averaged_metrics = {
84
- field: averaging_function([getattr(statistics, field) for statistics in tensor_statistics]) for field in fields
121
+ field: averaging_function(
122
+ [getattr(statistics, field) for statistics in tensor_statistics],
123
+ invalid_value_threshold,
124
+ tracker_name,
125
+ )
126
+ for field in fields
85
127
  }
86
128
 
87
129
  return TensorStatistics(**averaged_metrics)
88
130
 
89
131
 
90
132
  def apply_averaging_function_to_dictionary_of_tensor_statistics(
91
- data: dict[str, list[TensorStatistics]], averaging_function: Callable[[list[float]], float]
133
+ data: dict[str, list[TensorStatistics]],
134
+ averaging_function: Callable[[list[float], float, str], float],
135
+ invalid_value_threshold: float = 1e10,
136
+ tracker_name: str = "Unknown",
92
137
  ) -> dict[str, TensorStatistics]:
93
138
  """
94
139
  :param data: Dictionary mapping parameter group names to list of TensorStatistics from that parameter group.
95
140
  :param averaging_function: Function to average the values with.
141
+ :param invalid_value_threshold: Threshold for invalid observations.
142
+ :param tracker_name: Name of the tracker for error reporting.
96
143
  :return: Dictionary mapping parameter group names to TensorStatistics averaged over all statistics in the given
97
144
  TensorStatistics models.
98
145
  """
99
146
 
100
147
  return {
101
148
  group: apply_averaging_function_to_tensor_statistics(
102
- tensor_statistics=metrics, averaging_function=averaging_function
149
+ tensor_statistics=metrics,
150
+ averaging_function=averaging_function,
151
+ invalid_value_threshold=invalid_value_threshold,
152
+ tracker_name=tracker_name,
103
153
  )
104
154
  for group, metrics in data.items()
105
155
  }
106
156
 
107
157
 
108
158
  def apply_averaging_function_to_dictionary_of_metric_lists(
109
- data: dict[str, list[float]], averaging_function: Callable[[list[float]], float]
159
+ data: dict[str, list[float]],
160
+ averaging_function: Callable[[list[float], float, str], float],
161
+ invalid_value_threshold: float = 1e10,
162
+ tracker_name: str = "Unknown",
110
163
  ) -> dict[str, float]:
111
164
  """
112
165
  :param data: Dictionary mapping parameter group names to list of metrics from that parameter group.
113
166
  :param averaging_function: Function to average the values with.
167
+ :param invalid_value_threshold: Threshold for invalid observations.
168
+ :param tracker_name: Name of the tracker for error reporting.
114
169
  :return: Dictionary mapping parameter group names to averages over all metrics from each parameter group.
115
170
  """
116
171
 
117
- return {group: averaging_function(metrics) for group, metrics in data.items()}
172
+ return {
173
+ group: averaging_function(metrics, invalid_value_threshold, tracker_name) for group, metrics in data.items()
174
+ }
118
175
 
119
176
 
120
177
  def average_tensor_statistics(tensor_statistics: list[TensorStatistics]) -> TensorStatistics:
@@ -8,6 +8,9 @@ from abc import ABC, abstractmethod
8
8
  from copy import deepcopy
9
9
  from typing import Any, TypeAlias, final
10
10
 
11
+ import numpy as np
12
+ from loguru import logger
13
+
11
14
  from libinephany.observations import observation_utils
12
15
  from libinephany.observations.observation_utils import StatisticStorageTypes
13
16
  from libinephany.pydantic_models.configs.observer_config import ObserverConfig
@@ -65,6 +68,7 @@ class Observer(ABC):
65
68
  self.observer_config = observer_config
66
69
  self.standardize = standardizer if standardizer is not None else observation_utils.null_standardizer
67
70
  self.should_standardize = should_standardize and self.can_standardize
71
+ self.invalid_observation_threshold = observer_config.invalid_observation_threshold
68
72
 
69
73
  self.include_statistics: list[str] | None = None
70
74
  self.include_hparams = include_hparams
@@ -220,6 +224,47 @@ class Observer(ABC):
220
224
 
221
225
  self._validated_observation = True
222
226
 
227
+ @final
228
+ def _validate_observation_values(self, observations: list[float | int]) -> None:
229
+ """
230
+ :param observations: Observation vector to validate.
231
+ """
232
+ validate = True
233
+ if len(observations) == 0:
234
+ raise ValueError(f"Observer: {self.__class__.__name__} gathered observations with empty list!")
235
+
236
+ if any(not isinstance(observation, (float, int)) for observation in observations):
237
+ # Just in case
238
+ other_count = sum(1 for observation in observations if not isinstance(observation, (float, int)))
239
+ other_ratio = other_count / len(observations)
240
+ raise ValueError(
241
+ f"Observer: {self.__class__.__name__} gathered observations with invalid values (not float or int)! "
242
+ f"Ratio: {other_ratio:.2%} ({other_count}/{len(observations)})"
243
+ )
244
+
245
+ # Check for NaN values
246
+ if any(np.isnan(observation) for observation in observations):
247
+ nan_count = sum(1 for observation in observations if np.isnan(observation))
248
+ nan_ratio = nan_count / len(observations)
249
+ logger.warning(
250
+ f"Observer: {self.__class__.__name__} gathered observations with NaN values! Ratio: {nan_ratio:.2%} "
251
+ f" ({nan_count}/{len(observations)})"
252
+ )
253
+ validate = False
254
+
255
+ # Check for very large values (including negative values)
256
+ if any(abs(observation) > self.invalid_observation_threshold for observation in observations):
257
+ inf_count = sum(1 for observation in observations if observation > self.invalid_observation_threshold)
258
+ inf_ratio = inf_count / len(observations)
259
+ logger.warning(
260
+ f"Observer: {self.__class__.__name__} gathered observations with large values (threshold: "
261
+ f"{self.invalid_observation_threshold})! Ratio: {inf_ratio:.2%} ({inf_count}/{len(observations)})"
262
+ )
263
+ validate = False
264
+ logger.debug(
265
+ f'Observer: {self.__class__.__name__} observation validation {"passed" if validate else "failed"}!'
266
+ )
267
+
223
268
  @abstractmethod
224
269
  def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
225
270
  """
@@ -285,6 +330,8 @@ class Observer(ABC):
285
330
  if not self._validated_observation:
286
331
  self._validate_observation(observations=observations)
287
332
 
333
+ self._validate_observation_values(observations=observations)
334
+
288
335
  return observations, observations_dict
289
336
 
290
337
  @final
@@ -45,11 +45,12 @@ class Statistic(ABC):
45
45
  self,
46
46
  *,
47
47
  parameter_group_name: str,
48
- averaging_function: Callable[[list[float]], float],
48
+ averaging_function: Callable[[list[float], float, str], float],
49
49
  agent_modules: dict[str, str],
50
50
  max_statistic_cache_size: int = 3,
51
51
  tensor_stats_downsample_percentage: float = 0.01,
52
52
  statistic_sample_frequency: int = 10,
53
+ invalid_observation_threshold: float = 1e10,
53
54
  **kwargs,
54
55
  ) -> None:
55
56
  """
@@ -62,6 +63,7 @@ class Statistic(ABC):
62
63
  sample to compute the statistics from.
63
64
  :param statistic_sample_frequency: How frequently to cache tensors from the model for later statistic
64
65
  computation.
66
+ :param invalid_observation_threshold: Threshold for invalid observations, default is 1e10.
65
67
  :param kwargs: Other observation keyword arguments.
66
68
  """
67
69
 
@@ -77,6 +79,7 @@ class Statistic(ABC):
77
79
  self.downsample_percent = tensor_stats_downsample_percentage
78
80
  self.sample_frequency = statistic_sample_frequency
79
81
  self.include_statistics: list[str] | None = None
82
+ self.invalid_observation_threshold = invalid_observation_threshold
80
83
 
81
84
  @final
82
85
  @property
@@ -324,11 +327,18 @@ class Statistic(ABC):
324
327
  self._process_tensor_cache()
325
328
 
326
329
  results = observation_utils.apply_averaging_function_to_tensor_statistics(
327
- tensor_statistics=self._data, averaging_function=self._averaging_function # type: ignore
330
+ tensor_statistics=self._data, # type: ignore
331
+ averaging_function=self._averaging_function,
332
+ invalid_value_threshold=self.invalid_observation_threshold,
333
+ tracker_name=self.tracker_name,
328
334
  )
329
335
 
330
336
  elif self.storage_format is StatisticStorageTypes.FLOAT:
331
- results = self._averaging_function(self._data) # type: ignore
337
+ results = self._averaging_function(
338
+ self._data, # type: ignore
339
+ invalid_value_threshold=self.invalid_observation_threshold,
340
+ tracker_name=self.tracker_name,
341
+ )
332
342
 
333
343
  else:
334
344
  raise ValueError(f"Data storage type {self.storage_format} is invalid!")
@@ -7,8 +7,15 @@
7
7
  import math
8
8
  from typing import Any, Callable
9
9
 
10
- from loguru import logger
11
- from pydantic import BaseModel, field_validator
10
+ from pydantic import BaseModel
11
+
12
+ # ======================================================================================================================
13
+ #
14
+ # CONSTANTS
15
+ #
16
+ # ======================================================================================================================
17
+
18
+ VRAM_USAGES_KEY = "vram_usages"
12
19
 
13
20
  # ======================================================================================================================
14
21
  #
@@ -20,26 +27,40 @@ from pydantic import BaseModel, field_validator
20
27
  class InnerTaskProfile(BaseModel):
21
28
 
22
29
  inner_task_name: str
30
+
23
31
  number_of_agents: int
24
32
  number_of_layers: int
33
+ number_of_parameters: int
34
+
25
35
  observation_space_sizes: dict[str, int]
26
36
  action_space_sizes: dict[str, int]
27
- number_of_parameters: int
28
- vram_usage: float
29
- idle_vram_usage: float
30
- hparam_overrides: dict[str, dict[str, Any]] | None = None
31
37
 
32
- @field_validator("vram_usage", "idle_vram_usage", mode="before")
33
- def replace_non_with_nan(cls, value: None | float) -> float:
38
+ vram_usages: dict[int, tuple[float, float]] | None = None
39
+ expected_vram_usage: float | None = None
40
+ expected_idle_vram_usage: float | None = None
41
+ max_batch_size_override: int | None = None
42
+
43
+ @property
44
+ def vram_usage(self) -> float:
45
+ """
46
+ :return: VRAM usage at the max batch size.
47
+ """
48
+
49
+ if self.expected_vram_usage is None:
50
+ return float("nan")
51
+
52
+ return self.expected_vram_usage
53
+
54
+ @property
55
+ def idle_vram_usage(self) -> float:
34
56
  """
35
- :param value: Value to replace with NaN if it is None.
36
- :return: Either the given float value or NaN.
57
+ :return: Idle VRAM usage at the max batch size.
37
58
  """
38
59
 
39
- if value is None:
60
+ if self.expected_idle_vram_usage is None:
40
61
  return float("nan")
41
62
 
42
- return value
63
+ return self.expected_idle_vram_usage
43
64
 
44
65
  @property
45
66
  def failed_to_profile(self) -> bool:
@@ -49,19 +70,18 @@ class InnerTaskProfile(BaseModel):
49
70
 
50
71
  return math.isnan(self.vram_usage)
51
72
 
52
- def model_dump_json(self, **kwargs) -> str:
73
+ def model_dump(self, **kwargs) -> dict[str, Any]:
53
74
  """
54
75
  :param kwargs: Standard Pydantic model dump kwargs.
55
76
  :return: Dump result of the superclass' method.
56
77
  """
57
78
 
58
- logger.debug(
59
- f"Inner task {self.inner_task_name} consumed {self.vram_usage:.3f} MB of VRAM while training and "
60
- f"{self.idle_vram_usage:.3f} MB of VRAM while idle. It has {self.number_of_agents} agents across "
61
- f"{self.number_of_layers} inner model layers."
62
- )
79
+ super_dump = super().model_dump(**kwargs)
63
80
 
64
- return super().model_dump_json(**kwargs)
81
+ if self.vram_usages is not None:
82
+ super_dump[VRAM_USAGES_KEY] = {k: list(v) for k, v in self.vram_usages.items()}
83
+
84
+ return super_dump
65
85
 
66
86
 
67
87
  class InnerTaskProfiles(BaseModel):
@@ -235,42 +255,6 @@ class InnerTaskProfiles(BaseModel):
235
255
 
236
256
  return inner_task_name in self.profiles
237
257
 
238
- def add_profile(
239
- self,
240
- inner_task_name: str,
241
- number_of_agents: int,
242
- number_of_layers: int,
243
- observation_space_sizes: dict[str, int],
244
- action_space_sizes: dict[str, int],
245
- number_of_parameters: int,
246
- vram_usage: float,
247
- idle_vram_usage: float,
248
- hparam_overrides: dict[str, dict[str, Any]] | None = None,
249
- ) -> None:
250
- """
251
- :param inner_task_name: Name of the inner task to add a profile for.
252
- :param number_of_agents: Number of agents active in the inner task's environment.
253
- :param number_of_layers: Number of layers in the inner model.
254
- :param observation_space_sizes: Dictionary mapping agent IDs to their observation space sizes.
255
- :param action_space_sizes: Dictionary mapping agent IDs to their action space sizes.
256
- :param vram_usage: VRAM required to perform the inner task. Can be NaN if an OOM was encountered.
257
- :param idle_vram_usage: VRAM required for the inner task to sit loaded but not actively being trained. Can be
258
- NaN if an OOM was encountered.
259
- :param hparam_overrides: Hyperparameter overrides for the inner task.
260
- """
261
-
262
- self.profiles[inner_task_name] = InnerTaskProfile(
263
- inner_task_name=inner_task_name,
264
- number_of_agents=number_of_agents,
265
- number_of_layers=number_of_layers,
266
- observation_space_sizes=observation_space_sizes,
267
- action_space_sizes=action_space_sizes,
268
- number_of_parameters=number_of_parameters,
269
- vram_usage=vram_usage,
270
- idle_vram_usage=idle_vram_usage,
271
- hparam_overrides=hparam_overrides,
272
- )
273
-
274
258
  def validate_task_profiles(self, policy_mapping_function: Callable[[str, Any, Any], str]) -> None:
275
259
  """
276
260
  :param policy_mapping_function: Function which maps agent IDs to policy IDs.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 1.1.2
3
+ Version: 1.1.4
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -1 +0,0 @@
1
- 1.1.2
File without changes
File without changes
File without changes
File without changes
File without changes