rock-physics-open 0.2.3-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rock-physics-open has been flagged as potentially problematic.
Files changed (42)
  1. rock_physics_open/equinor_utilities/gen_utilities/dict_to_float.py +6 -1
  2. rock_physics_open/equinor_utilities/gen_utilities/dim_check_vector.py +35 -5
  3. rock_physics_open/equinor_utilities/gen_utilities/filter_input.py +11 -6
  4. rock_physics_open/equinor_utilities/gen_utilities/filter_output.py +29 -19
  5. rock_physics_open/equinor_utilities/machine_learning_utilities/__init__.py +18 -5
  6. rock_physics_open/equinor_utilities/machine_learning_utilities/base_pressure_model.py +172 -0
  7. rock_physics_open/equinor_utilities/machine_learning_utilities/exponential_model.py +100 -86
  8. rock_physics_open/equinor_utilities/machine_learning_utilities/friable_pressure_models.py +230 -0
  9. rock_physics_open/equinor_utilities/machine_learning_utilities/import_ml_models.py +23 -4
  10. rock_physics_open/equinor_utilities/machine_learning_utilities/patchy_cement_pressure_models.py +280 -0
  11. rock_physics_open/equinor_utilities/machine_learning_utilities/polynomial_model.py +128 -0
  12. rock_physics_open/equinor_utilities/machine_learning_utilities/sigmoidal_model.py +204 -155
  13. rock_physics_open/equinor_utilities/optimisation_utilities/__init__.py +19 -0
  14. rock_physics_open/equinor_utilities/snapshot_test_utilities/compare_snapshots.py +1 -2
  15. rock_physics_open/fluid_models/brine_model/brine_properties.py +70 -35
  16. rock_physics_open/fluid_models/gas_model/gas_properties.py +79 -37
  17. rock_physics_open/fluid_models/oil_model/dead_oil_density.py +21 -16
  18. rock_physics_open/fluid_models/oil_model/dead_oil_velocity.py +9 -7
  19. rock_physics_open/fluid_models/oil_model/live_oil_density.py +16 -13
  20. rock_physics_open/fluid_models/oil_model/live_oil_velocity.py +3 -3
  21. rock_physics_open/fluid_models/oil_model/oil_properties.py +59 -29
  22. rock_physics_open/sandstone_models/__init__.py +2 -0
  23. rock_physics_open/sandstone_models/constant_cement_optimisation.py +4 -1
  24. rock_physics_open/sandstone_models/friable_optimisation.py +4 -1
  25. rock_physics_open/sandstone_models/patchy_cement_model.py +89 -6
  26. rock_physics_open/sandstone_models/patchy_cement_optimisation.py +4 -1
  27. rock_physics_open/t_matrix_models/__init__.py +0 -10
  28. rock_physics_open/t_matrix_models/carbonate_pressure_substitution.py +1 -1
  29. rock_physics_open/t_matrix_models/curvefit_t_matrix_exp.py +1 -2
  30. rock_physics_open/t_matrix_models/t_matrix_opt_fluid_sub_exp.py +3 -3
  31. rock_physics_open/t_matrix_models/t_matrix_opt_fluid_sub_petec.py +5 -1
  32. rock_physics_open/t_matrix_models/t_matrix_opt_forward_model_exp.py +5 -1
  33. rock_physics_open/t_matrix_models/t_matrix_opt_forward_model_min.py +4 -1
  34. rock_physics_open/t_matrix_models/t_matrix_parameter_optimisation_exp.py +5 -1
  35. rock_physics_open/t_matrix_models/t_matrix_parameter_optimisation_min.py +4 -1
  36. rock_physics_open/version.py +2 -2
  37. {rock_physics_open-0.2.3.dist-info → rock_physics_open-0.3.0.dist-info}/METADATA +4 -8
  38. {rock_physics_open-0.2.3.dist-info → rock_physics_open-0.3.0.dist-info}/RECORD +42 -37
  39. /rock_physics_open/{t_matrix_models → equinor_utilities/optimisation_utilities}/opt_subst_utilities.py +0 -0
  40. {rock_physics_open-0.2.3.dist-info → rock_physics_open-0.3.0.dist-info}/WHEEL +0 -0
  41. {rock_physics_open-0.2.3.dist-info → rock_physics_open-0.3.0.dist-info}/licenses/LICENSE +0 -0
  42. {rock_physics_open-0.2.3.dist-info → rock_physics_open-0.3.0.dist-info}/top_level.txt +0 -0

rock_physics_open/equinor_utilities/gen_utilities/dict_to_float.py
@@ -1,4 +1,9 @@
-def dict_value_to_float(input_dict):
+from typing import Any
+
+
+def dict_value_to_float(
+    input_dict: dict[str, Any],
+) -> dict[str, float | list[float]]:
     """
     Convert dictionary strings to floating point numbers. Each value can have multiple floats.
 
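
As an illustration of the new signature only (the accepted string format, e.g. the separator used when a value holds several numbers, is not visible in this hunk), a hypothetical call would look like:

    from rock_physics_open.equinor_utilities.gen_utilities.dict_to_float import (
        dict_value_to_float,
    )

    # Per the return annotation, a single-number string maps to a float and a
    # value holding several numbers maps to a list of floats.
    converted = dict_value_to_float({"porosity": "0.25"})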

rock_physics_open/equinor_utilities/gen_utilities/dim_check_vector.py
@@ -1,8 +1,38 @@
+from typing import Any, overload
+
 import numpy as np
+import numpy.typing as npt
 import pandas as pd
 
 
-def dim_check_vector(args, force_type=None):
+@overload
+def dim_check_vector(
+    args: list[Any] | tuple[Any, ...],
+    force_type: np.dtype | None = ...,
+) -> list[npt.NDArray[Any] | pd.DataFrame]:
+    """Overload for when the input is a list or tuple."""
+
+
+@overload
+def dim_check_vector(
+    args: pd.DataFrame,
+    force_type: np.dtype | None = ...,
+) -> pd.DataFrame:
+    """Overload for when the input is a pandas DataFrame."""
+
+
+@overload
+def dim_check_vector(
+    args: npt.NDArray[Any],
+    force_type: np.dtype | None = ...,
+) -> npt.NDArray[Any]:
+    """Overload for when the input is a numpy array."""
+
+
+def dim_check_vector(
+    args: list[Any] | tuple[Any, ...] | npt.NDArray[Any] | pd.DataFrame,
+    force_type: np.dtype | None = None,
+) -> npt.NDArray[Any] | pd.DataFrame | list[npt.NDArray[Any] | pd.DataFrame]:
     """
     Check that all inputs are of the same (one-dimensional) size. Raise ValueError in case there are several lengths
     present in the inputs. All inputs will be checked and possibly expanded to common length. Only the first dimension
@@ -23,8 +53,8 @@ def dim_check_vector(args, force_type=None):
     single_types = (np.ndarray, pd.DataFrame)
     iterable_types = (list, tuple)
     allowed_types = single_types + iterable_types
-    if not isinstance(args, allowed_types):
-        raise ValueError("dim_check_vector: unknown input type: {}".format(type(args)))
+    if not isinstance(args, allowed_types):  # pyright: ignore[reportUnnecessaryIsInstance] | Kept for backward compatibility
+        raise ValueError("dim_check_vector: unknown input type: {}".format(type(args)))  # pyright: ignore[reportUnreachable] | Kept for backward compatibility
 
     # Single array or dataframe is just returned
     if isinstance(args, single_types):
@@ -56,13 +86,13 @@ def dim_check_vector(args, force_type=None):
     args = [np.array(item, ndmin=1) if np.isscalar(item) else item for item in args]
 
     # Can now test for length - must either be a scalar or have the same length
-    max_length = np.max([item.shape[0] for item in args])
+    max_length: int = np.max([item.shape[0] for item in args])
     if not np.all([item.shape[0] == max_length or item.shape[0] == 1 for item in args]):
         raise ValueError(
             "dim_check_vector: Unequal array lengths in input to dim_check_vector"
         )
 
-    output_arg = []
+    output_arg: list[npt.NDArray[Any] | pd.DataFrame] = []
     for item in args:
         if item.shape[0] == max_length:
             output_arg.append(item)
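
The overloads above only sharpen the static types; the runtime contract is unchanged: a lone array or DataFrame passes straight through, while a list or tuple is length-checked and length-1 members are expanded to the common length. A minimal sketch with hypothetical values:

    import numpy as np

    from rock_physics_open.equinor_utilities.gen_utilities.dim_check_vector import (
        dim_check_vector,
    )

    vp = np.array([2000.0, 2100.0, 2200.0])
    rho = np.array([2.3])  # length 1, expanded to the common length of 3

    # List in, list out - and a type checker can now infer the element types.
    vp_out, rho_out = dim_check_vector([vp, rho])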

rock_physics_open/equinor_utilities/gen_utilities/filter_input.py
@@ -1,14 +1,20 @@
 from sys import byteorder
+from typing import Any
 
 import numpy as np
+import numpy.typing as npt
 import pandas as pd
 
 WRONG_BYTEORDER = ">" if byteorder == "little" else "<"
 
 
 def filter_input_log(
-    args, working_int=None, negative=False, no_zero=False, positive=True
-):
+    args: list[Any] | tuple[Any, ...] | npt.NDArray[Any] | pd.DataFrame,
+    working_int: npt.NDArray[Any] | None = None,
+    negative: bool = False,
+    no_zero: bool = False,
+    positive: bool = True,
+) -> tuple[npt.NDArray[np.bool_], list[npt.NDArray[Any] | pd.DataFrame]]:
     """
     Check for valid input values in numpy arrays or pandas data frames. Default behaviour is to
     identify missing values - assumed to be NaN and Inf. Other conditions
@@ -40,9 +46,8 @@ def filter_input_log(
     type_error = "filter_input_log: unknown input data type: {}".format(type(args))
     size_error = "filter_input_log: inputs of different length"
 
-    if not isinstance(args, (list, tuple, np.ndarray, pd.DataFrame)):
-        raise ValueError(type_error)
-
+    if not isinstance(args, (list, tuple, np.ndarray, pd.DataFrame)):  # pyright: ignore[reportUnnecessaryIsInstance] | Kept for backward compatibility
+        raise ValueError(type_error)  # pyright: ignore[reportUnreachable] | Kept for backward compatibility
     # Make sure that 'args' is iterable
     if isinstance(args, (np.ndarray, pd.DataFrame)):
         args = [args]
@@ -70,7 +75,7 @@ def filter_input_log(
     # https://github.com/pandas-dev/pandas/issues/32432
     # idx = ~logs.any(bool_only=True, axis=1)
     # Need to do it the cumbersome way for the time being
-    bool_col = logs.dtypes.apply(lambda dtype: dtype == "bool")
+    bool_col = logs.dtypes == "bool"
     if any(bool_col):
         idx = ~logs.loc[:, logs.columns[bool_col]].any(axis=1)
         logs.drop(columns=logs.columns[bool_col], inplace=True)

rock_physics_open/equinor_utilities/gen_utilities/filter_output.py
@@ -1,12 +1,17 @@
-from sys import byteorder
+from typing import Any
 
 import numpy as np
+import numpy.typing as npt
 import pandas as pd
 
-WRONG_BYTEORDER = ">" if byteorder == "little" else "<"
 
-
-def filter_output(idx_inp, inp_log):
+def filter_output(
+    idx_inp: npt.NDArray[np.bool_],
+    inp_log: list[npt.NDArray[Any] | pd.DataFrame]
+    | tuple[npt.NDArray[Any] | pd.DataFrame, ...]
+    | npt.NDArray[Any]
+    | pd.DataFrame,
+) -> list[npt.NDArray[Any] | pd.DataFrame]:
     """
     Function to restore outputs from a plugin to original length and
     with values at correct positions. The logs are assumed to go through
@@ -27,7 +32,9 @@ def filter_output(idx_inp, inp_log):
         Expanded inputs.
     """
 
-    def _expand_array(idx, inp_single_log):
+    def _expand_array(
+        idx: npt.NDArray[np.bool_], inp_single_log: npt.NDArray[Any]
+    ) -> npt.NDArray[Any]:
         logs = np.ones(idx.shape, dtype=float) * np.nan
         try:
             logs[idx] = inp_single_log.flatten()
@@ -37,30 +44,33 @@ def filter_output(idx_inp, inp_log):
             logs[idx] = inp_single_log
         return logs.reshape(idx.shape)
 
-    def _expand_df(idx, inp_df):
-        logs = pd.DataFrame(columns=inp_df.columns, index=np.arange(idx.shape[0]))
+    def _expand_df(idx: npt.NDArray[np.bool_], inp_df: pd.DataFrame) -> pd.DataFrame:
+        logs = pd.DataFrame(
+            columns=inp_df.columns, index=np.arange(idx.shape[0], dtype=np.intp)
+        )
         logs.loc[idx] = inp_df
         return logs
 
-    if not isinstance(inp_log, (list, tuple, np.ndarray, pd.DataFrame)):
-        raise ValueError(
+    if not isinstance(inp_log, (list, tuple, np.ndarray, pd.DataFrame)):  # pyright: ignore[reportUnnecessaryIsInstance] | Kept for backward compatibility
+        raise ValueError(  # pyright: ignore[reportUnreachable] | Kept for backward compatibility
            "filter_output: unknown input data type: {}".format(type(inp_log))
         )
-    if not isinstance(idx_inp, (list, np.ndarray)):
-        raise ValueError(
+    if not isinstance(idx_inp, (list, np.ndarray)):  # pyright: ignore[reportUnnecessaryIsInstance] | Kept for backward compatibility
+        raise ValueError(  # pyright: ignore[reportUnreachable] | Kept for backward compatibility
             "filter_output: unknown filter array data type: {}".format(type(idx_inp))
         )
 
     # Make iterable in case of single input
     if isinstance(inp_log, (np.ndarray, pd.DataFrame)):
         inp_log = [inp_log]
-    if isinstance(idx_inp, np.ndarray):
-        idx_inp = [idx_inp]
+
+    if isinstance(idx_inp, np.ndarray):  # pyright: ignore[reportUnnecessaryIsInstance] | Kept for backward compatibility
+        idx_inp_ = [idx_inp]
 
     # Possible to simplify?
-    if len(idx_inp) != len(inp_log):
-        if len(idx_inp) == 1:
-            idx_inp = idx_inp * len(inp_log)
+    if len(idx_inp_) != len(inp_log):
+        if len(idx_inp_) == 1:
+            idx_inp_ = idx_inp_ * len(inp_log)
         else:
             raise ValueError(
                 "filter_output: mismatch between length of filter arrays and inputs: {} and {}".format(
@@ -68,11 +78,11 @@ def filter_output(idx_inp, inp_log):
             )
         )
 
-    return_logs = []
-    for this_idx, this_log in zip(idx_inp, inp_log):
+    return_logs: list[npt.NDArray[Any] | pd.DataFrame] = []
+    for this_idx, this_log in zip(idx_inp_, inp_log):
         if isinstance(this_log, np.ndarray):
             return_logs.append(_expand_array(this_idx, this_log))
-        elif isinstance(this_log, pd.DataFrame):
+        elif isinstance(this_log, pd.DataFrame):  # pyright: ignore[reportUnnecessaryIsInstance] | Kept for backward compatibility
             return_logs.append(_expand_df(this_idx, this_log))
 
     return return_logs
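
filter_input_log and filter_output form a round trip: the boolean mask produced while filtering the inputs is what restores results to the original length, with NaN at the rejected positions. A sketch under that assumption, with hypothetical data:

    import numpy as np

    from rock_physics_open.equinor_utilities.gen_utilities.filter_input import (
        filter_input_log,
    )
    from rock_physics_open.equinor_utilities.gen_utilities.filter_output import (
        filter_output,
    )

    vp = np.array([2000.0, np.nan, 2200.0, 2400.0])
    idx, (vp_valid,) = filter_input_log(vp)   # mask of valid samples + filtered log
    result = 0.5 * vp_valid                   # any per-sample computation
    (restored,) = filter_output(idx, result)  # NaN where samples were rejected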

rock_physics_open/equinor_utilities/machine_learning_utilities/__init__.py
@@ -1,14 +1,27 @@
 from .dummy_vars import generate_dummy_vars
-from .exponential_model import CarbonateExponentialPressure
+from .exponential_model import ExponentialPressureModel
+from .friable_pressure_models import (
+    FriableDryBulkModulusPressureModel,
+    FriableDryShearModulusPressureModel,
+)
 from .import_ml_models import import_model
+from .patchy_cement_pressure_models import (
+    PatchyCementDryBulkModulusPressureModel,
+    PatchyCementDryShearModulusPressureModel,
+)
+from .polynomial_model import PolynomialPressureModel
 from .run_regression import run_regression
-from .sigmoidal_model import CarbonateSigmoidalPressure, Sigmoid
+from .sigmoidal_model import SigmoidalPressureModel
 
 __all__ = [
     "generate_dummy_vars",
-    "CarbonateExponentialPressure",
     "import_model",
     "run_regression",
-    "CarbonateSigmoidalPressure",
-    "Sigmoid",
+    "ExponentialPressureModel",
+    "PolynomialPressureModel",
+    "SigmoidalPressureModel",
+    "FriableDryBulkModulusPressureModel",
+    "FriableDryShearModulusPressureModel",
+    "PatchyCementDryShearModulusPressureModel",
+    "PatchyCementDryBulkModulusPressureModel",
 ]
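
For downstream code the renames mean import sites change with this release, and Sigmoid is no longer exported at all:

    # 0.2.3:
    # from rock_physics_open.equinor_utilities.machine_learning_utilities import (
    #     CarbonateExponentialPressure,
    #     CarbonateSigmoidalPressure,
    # )

    # 0.3.0:
    from rock_physics_open.equinor_utilities.machine_learning_utilities import (
        ExponentialPressureModel,
        SigmoidalPressureModel,
    )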

rock_physics_open/equinor_utilities/machine_learning_utilities/base_pressure_model.py (new file)
@@ -0,0 +1,172 @@
+from __future__ import annotations
+
+import pickle
+from abc import ABC, abstractmethod
+from typing import Any, Self
+
+import numpy as np
+
+
+class BasePressureModel(ABC):
+    """
+    Abstract base class for pressure sensitivity models.
+
+    All pressure models follow the convention:
+    - predict(): returns differential change (depleted - in_situ)
+    - predict_abs(): returns absolute values for specified case
+    - predict_max(): uses model_max_pressure instead of depleted pressure
+
+    Input validation is delegated to concrete implementations since
+    each model has different column requirements.
+    """
+
+    def __init__(self, model_max_pressure: float | None = None, description: str = ""):
+        """
+        Initialize base pressure model.
+
+        Parameters
+        ----------
+        model_max_pressure : float | None
+            Maximum pressure for predict_max method. Required for predict_max to work.
+        description : str
+            Human-readable description of the model instance.
+        """
+        self._model_max_pressure = model_max_pressure
+        self._description = description
+
+    @property
+    def max_pressure(self) -> float | None:
+        """Maximum pressure setting for predict_max method."""
+        return self._model_max_pressure
+
+    @property
+    def description(self) -> str:
+        """Model description."""
+        return self._description
+
+    def predict(self, inp_arr: np.ndarray) -> np.ndarray:
+        """
+        Predict differential change: result(depleted) - result(in_situ).
+
+        Parameters
+        ----------
+        inp_arr : np.ndarray
+            Input array with pressure columns and other model-specific parameters.
+
+        Returns
+        -------
+        np.ndarray
+            Differential change values.
+        """
+        arr = self.validate_input(inp_arr)
+        return self.predict_abs(arr, case="depleted") - self.predict_abs(
+            arr, case="in_situ"
+        )
+
+    def predict_max(self, inp_arr: np.ndarray) -> np.ndarray:
+        """
+        Predict using model_max_pressure instead of depleted pressure.
+
+        Parameters
+        ----------
+        inp_arr : np.ndarray
+            Input array where last column (depleted pressure) will be replaced.
+
+        Returns
+        -------
+        np.ndarray
+            Values at model_max_pressure minus values at in_situ pressure.
+
+        Raises
+        ------
+        ValueError
+            If model_max_pressure is not set.
+        """
+        if self._model_max_pressure is None:
+            raise ValueError('Field "model_max_pressure" is not set')
+
+        arr = self.validate_input(inp_arr).copy()
+        # Replace last column (assumed to be depleted pressure) with max pressure
+        arr[:, -1] = self._model_max_pressure
+        return self.predict_abs(arr, case="depleted") - self.predict_abs(
+            arr, case="in_situ"
+        )
+
+    @abstractmethod
+    def validate_input(self, inp_arr: np.ndarray) -> np.ndarray:
+        """
+        Validate input array format for this specific model.
+
+        Parameters
+        ----------
+        inp_arr : np.ndarray
+            Input array to validate.
+
+        Returns
+        -------
+        np.ndarray
+            Validated input array.
+
+        Raises
+        ------
+        ValueError
+            If input format is invalid for this model.
+        """
+
+    @abstractmethod
+    def predict_abs(self, inp_arr: np.ndarray, case: str = "in_situ") -> np.ndarray:
+        """
+        Predict absolute values for specified pressure case.
+
+        Parameters
+        ----------
+        inp_arr : np.ndarray
+            Validated input array.
+        case : str
+            Either "in_situ" or "depleted" to specify which pressure to use.
+
+        Returns
+        -------
+        np.ndarray
+            Absolute predicted values.
+        """
+
+    @abstractmethod
+    def todict(self) -> dict[str, Any]:
+        """
+        Convert model to dictionary for serialization.
+
+        Returns
+        -------
+        dict[str, Any]
+            Dictionary containing all model parameters.
+        """
+
+    def save(self, file: str | bytes) -> None:
+        """
+        Save model to pickle file.
+
+        Parameters
+        ----------
+        file : str | bytes
+            File path for saving.
+        """
+        with open(file, "wb") as f_out:
+            pickle.dump(self.todict(), f_out)
+
+    @classmethod
+    @abstractmethod
+    def load(cls, file: str | bytes) -> Self:
+        """
+        Load model from pickle file.
+
+        Parameters
+        ----------
+        file : str | bytes
+            File path for loading.
+
+        Returns
+        -------
+        BasePressureModel
+            Loaded model instance.
+        """
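
BasePressureModel is a template: predict() and predict_max() are written once in terms of the abstract predict_abs(), so a concrete model only has to supply validation, the absolute prediction, and (de)serialisation. A minimal hypothetical subclass (LinearPressureModel is illustrative only, not part of the package):

    import pickle

    import numpy as np

    from rock_physics_open.equinor_utilities.machine_learning_utilities.base_pressure_model import (
        BasePressureModel,
    )


    class LinearPressureModel(BasePressureModel):
        """Illustrative model: value varies linearly with effective pressure."""

        def __init__(self, slope: float, model_max_pressure: float | None = None, description: str = ""):
            super().__init__(model_max_pressure, description)
            self._slope = slope

        def validate_input(self, inp_arr: np.ndarray) -> np.ndarray:
            if not (isinstance(inp_arr, np.ndarray) and inp_arr.ndim == 2 and inp_arr.shape[1] == 3):
                raise ValueError("Input must be (n,3): [value, p_in_situ, p_depleted]")
            return inp_arr

        def predict_abs(self, inp_arr: np.ndarray, case: str = "in_situ") -> np.ndarray:
            arr = self.validate_input(inp_arr)
            p_eff = arr[:, 1] if case == "in_situ" else arr[:, 2]
            return arr[:, 0] + self._slope * p_eff

        def todict(self):
            return {
                "slope": self._slope,
                "model_max_pressure": self.max_pressure,
                "description": self.description,
            }

        @classmethod
        def load(cls, file):
            with open(file, "rb") as f_in:
                return cls(**pickle.load(f_in))

With the inherited methods, predict() then evaluates to slope * (p_depleted - p_in_situ), and save()/load() round-trip through todict().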

rock_physics_open/equinor_utilities/machine_learning_utilities/exponential_model.py
@@ -1,119 +1,133 @@
+from __future__ import annotations
+
 import pickle
-from typing import Union
+from typing import Any
 
 import numpy as np
 
+from .base_pressure_model import BasePressureModel
 
-def _verify_input(inp_arr):
-    if isinstance(inp_arr, np.ndarray) and not (
-        inp_arr.ndim == 2 and inp_arr.shape[1] == 3
-    ):
-        raise ValueError(
-            "Input to predict method should be an nx3 numpy array with columns velocity, in situ "
-            "pressure and depleted pressure"
-        )
 
+class ExponentialPressureModel(BasePressureModel):
+    """
+    Exponential pressure sensitivity model for velocity prediction.
+
+    Uses exponential decay function: v = v0 * (1 - a*exp(-p/b)) / (1 - a*exp(-p0/b))
+    where v0 is reference velocity, p is pressure, a and b are model parameters.
+
+    Input format (n,3): [velocity, p_eff_in_situ, p_eff_depleted]
+    """
 
-class CarbonateExponentialPressure:
     def __init__(
         self,
-        a_factor: float = None,
-        b_factor: float = None,
-        model_max_pressure: float = None,
+        a_factor: float,
+        b_factor: float,
+        model_max_pressure: float | None = None,
         description: str = "",
     ):
+        """
+        Initialize exponential pressure model.
+
+        Parameters
+        ----------
+        a_factor : float
+            Exponential amplitude parameter [unitless].
+        b_factor : float
+            Exponential decay parameter [Pa].
+        model_max_pressure : float | None
+            Maximum pressure for predict_max method [Pa].
+        description : str
+            Model description.
+        """
+        super().__init__(model_max_pressure, description)
         self._a_factor = a_factor
         self._b_factor = b_factor
-        self._model_max_pressure = model_max_pressure
-        self._description = description
-
-    def todict(self):
-        return {
-            "a_factor": self._a_factor,
-            "b_factor": self._b_factor,
-            "model_max_pressure": self._model_max_pressure,
-            "description": self._description,
-        }
 
     @property
     def a_factor(self) -> float:
+        """Exponential amplitude factor."""
         return self._a_factor
 
     @property
     def b_factor(self) -> float:
+        """Exponential decay factor."""
        return self._b_factor
 
-    @property
-    def max_pressure(self) -> float:
-        return self._model_max_pressure
+    def validate_input(self, inp_arr: np.ndarray) -> np.ndarray:
+        """
+        Validate input for exponential model.
+
+        Parameters
+        ----------
+        inp_arr : np.ndarray
+            Input array to validate.
+
+        Returns
+        -------
+        np.ndarray
+            Validated input array.
+
+        Raises
+        ------
+        ValueError
+            If input format is invalid.
+        """
+        if not isinstance(inp_arr, np.ndarray):
+            raise ValueError("Input must be numpy ndarray.")
+        if inp_arr.ndim != 2 or inp_arr.shape[1] != 3:
+            raise ValueError(
+                "Input must be (n,3): [velocity, p_eff_in_situ, p_eff_depleted]"
+            )
+        return inp_arr
 
-    @property
-    def description(self) -> str:
-        return self._description
-
-    def predict(self, inp_arr: np.ndarray) -> Union[np.ndarray, None]:
-        _verify_input(inp_arr)
-        if not self._valid():
-            return None
-        vel = inp_arr[:, 0]
-        eff_pres_in_situ = inp_arr[:, 1]
-        eff_pres_depl = inp_arr[:, 2]
-        # Return differential velocity to match alternative models
-        return (
-            vel
-            * (1.0 - self._a_factor * np.exp(-eff_pres_depl / self._b_factor))
-            / (1.0 - self._a_factor * np.exp(-eff_pres_in_situ / self._b_factor))
-            - vel
-        )
+    def predict_abs(self, inp_arr: np.ndarray, case: str = "in_situ") -> np.ndarray:
+        """
+        Calculate absolute velocity for specified pressure case.
 
-    def predict_max(self, inp_arr: np.ndarray) -> Union[np.ndarray, None]:
-        _verify_input(inp_arr)
-        if not self._valid():
-            return None
-        vel = inp_arr[:, 0]
-        eff_pres_in_situ = inp_arr[:, 1]
-        return (
-            vel
-            * (
-                1.0
-                - self._a_factor * np.exp(-self._model_max_pressure / self._b_factor)
-            )
-            / (1.0 - self._a_factor * np.exp(-eff_pres_in_situ / self.b_factor))
-        )
+        Parameters
+        ----------
+        inp_arr : np.ndarray
+            Validated input array (n,3).
+        case : str
+            Pressure case: "in_situ" or "depleted".
+
+        Returns
+        -------
+        np.ndarray
+            Velocity values [m/s].
+        """
+        arr = self.validate_input(inp_arr)
+
+        vel = arr[:, 0]
+        p_in_situ = arr[:, 1]
+        p_depleted = arr[:, 2]
+
+        p_eff = p_in_situ if case == "in_situ" else p_depleted
 
-    def predict_abs(self, inp_arr: np.ndarray) -> Union[np.ndarray, None]:
-        _verify_input(inp_arr)
-        if not self._valid():
-            return None
-        vel = inp_arr[:, 0]
-        eff_pres_in_situ = inp_arr[:, 1]
-        eff_pres_depl = inp_arr[:, 2]
         return (
             vel
-            * (1.0 - self._a_factor * np.exp(-eff_pres_depl / self._b_factor))
-            / (1.0 - self._a_factor * np.exp(-eff_pres_in_situ / self._b_factor))
+            * (1.0 - self._a_factor * np.exp(-p_eff / self._b_factor))
+            / (1.0 - self._a_factor * np.exp(-p_in_situ / self._b_factor))
         )
 
-    def save(self, file):
-        with open(file, "wb") as f_out:
-            pickle.dump(self.todict(), f_out)
+    def todict(self) -> dict[str, Any]:
+        """Convert model to dictionary."""
+        return {
+            "a_factor": self._a_factor,
+            "b_factor": self._b_factor,
+            "model_max_pressure": self._model_max_pressure,
+            "description": self._description,
+        }
 
     @classmethod
-    def load(cls, file):
+    def load(cls, file: str | bytes) -> "ExponentialPressureModel":
+        """Load exponential model from pickle file."""
         with open(file, "rb") as f_in:
-            inp_pcl = pickle.load(f_in)
-        return cls(
-            a_factor=inp_pcl["a_factor"],
-            b_factor=inp_pcl["b_factor"],
-            model_max_pressure=inp_pcl["model_max_pressure"],
-            description=inp_pcl["description"],
-        )
+            d = pickle.load(f_in)
 
-    def _valid(self):
-        if self.a_factor is None:
-            raise ValueError('object field "a_factor" is not set')
-        if self.b_factor is None:
-            raise ValueError('object field "b_factor" is not set')
-        if self.max_pressure is None:
-            raise ValueError('object field "max_pressure" is not set')
-        return True
+        return cls(
+            a_factor=d["a_factor"],
+            b_factor=d["b_factor"],
+            model_max_pressure=d["model_max_pressure"],
+            description=d["description"],
+        )
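
Because predict_abs() normalises by the in situ term, the "in_situ" case returns the input velocity unchanged, and the inherited predict() gives the depletion-induced change. A usage sketch with hypothetical parameter values (pressures in Pa, per the docstring):

    import numpy as np

    from rock_physics_open.equinor_utilities.machine_learning_utilities import (
        ExponentialPressureModel,
    )

    model = ExponentialPressureModel(a_factor=0.3, b_factor=10.0e6, model_max_pressure=40.0e6)

    # Columns: velocity [m/s], in situ and depleted effective pressure [Pa]
    inp = np.array([[2500.0, 20.0e6, 15.0e6]])

    dv = model.predict(inp)                           # depleted minus in situ velocity
    v_depl = model.predict_abs(inp, case="depleted")  # absolute depleted velocity
    dv_max = model.predict_max(inp)                   # change when depleting to model_max_pressure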