ert-19.0.0rc4-py3-none-any.whl → ert-20.0.0b0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. ert/__main__.py +94 -63
  2. ert/analysis/_es_update.py +11 -14
  3. ert/config/__init__.py +3 -2
  4. ert/config/_create_observation_dataframes.py +51 -375
  5. ert/config/_observations.py +483 -200
  6. ert/config/_read_summary.py +4 -5
  7. ert/config/ert_config.py +53 -80
  8. ert/config/everest_control.py +40 -39
  9. ert/config/everest_response.py +1 -13
  10. ert/config/field.py +0 -72
  11. ert/config/forward_model_step.py +17 -1
  12. ert/config/gen_data_config.py +14 -17
  13. ert/config/observation_config_migrations.py +821 -0
  14. ert/config/parameter_config.py +18 -28
  15. ert/config/parsing/__init__.py +0 -1
  16. ert/config/parsing/_parse_zonemap.py +45 -0
  17. ert/config/parsing/config_keywords.py +1 -1
  18. ert/config/parsing/config_schema.py +2 -8
  19. ert/config/parsing/observations_parser.py +2 -0
  20. ert/config/response_config.py +5 -23
  21. ert/config/rft_config.py +44 -19
  22. ert/config/summary_config.py +1 -13
  23. ert/config/surface_config.py +0 -57
  24. ert/dark_storage/compute/misfits.py +0 -42
  25. ert/dark_storage/endpoints/__init__.py +0 -2
  26. ert/dark_storage/endpoints/experiments.py +2 -5
  27. ert/dark_storage/json_schema/experiment.py +1 -2
  28. ert/field_utils/__init__.py +0 -2
  29. ert/field_utils/field_utils.py +1 -117
  30. ert/gui/ertwidgets/listeditbox.py +9 -1
  31. ert/gui/ertwidgets/models/ertsummary.py +20 -6
  32. ert/gui/ertwidgets/pathchooser.py +9 -1
  33. ert/gui/ertwidgets/stringbox.py +11 -3
  34. ert/gui/ertwidgets/textbox.py +10 -3
  35. ert/gui/ertwidgets/validationsupport.py +19 -1
  36. ert/gui/main_window.py +11 -6
  37. ert/gui/simulation/experiment_panel.py +1 -1
  38. ert/gui/simulation/run_dialog.py +11 -1
  39. ert/gui/tools/manage_experiments/export_dialog.py +4 -0
  40. ert/gui/tools/manage_experiments/manage_experiments_panel.py +1 -0
  41. ert/gui/tools/manage_experiments/storage_info_widget.py +5 -2
  42. ert/gui/tools/manage_experiments/storage_widget.py +18 -3
  43. ert/gui/tools/plot/data_type_proxy_model.py +1 -1
  44. ert/gui/tools/plot/plot_api.py +35 -27
  45. ert/gui/tools/plot/plot_widget.py +5 -0
  46. ert/gui/tools/plot/plot_window.py +4 -7
  47. ert/run_models/ensemble_experiment.py +1 -3
  48. ert/run_models/ensemble_smoother.py +1 -3
  49. ert/run_models/everest_run_model.py +12 -13
  50. ert/run_models/initial_ensemble_run_model.py +19 -22
  51. ert/run_models/model_factory.py +7 -7
  52. ert/run_models/multiple_data_assimilation.py +1 -3
  53. ert/sample_prior.py +12 -14
  54. ert/services/__init__.py +7 -3
  55. ert/services/_storage_main.py +59 -22
  56. ert/services/ert_server.py +186 -24
  57. ert/shared/version.py +3 -3
  58. ert/storage/local_ensemble.py +46 -115
  59. ert/storage/local_experiment.py +0 -16
  60. ert/utils/__init__.py +20 -0
  61. ert/warnings/specific_warning_handler.py +3 -2
  62. {ert-19.0.0rc4.dist-info → ert-20.0.0b0.dist-info}/METADATA +4 -51
  63. {ert-19.0.0rc4.dist-info → ert-20.0.0b0.dist-info}/RECORD +75 -80
  64. everest/bin/everest_script.py +5 -5
  65. everest/bin/kill_script.py +2 -2
  66. everest/bin/monitor_script.py +2 -2
  67. everest/bin/utils.py +4 -4
  68. everest/detached/everserver.py +6 -6
  69. everest/gui/everest_client.py +0 -6
  70. everest/gui/main_window.py +2 -2
  71. everest/util/__init__.py +1 -19
  72. ert/dark_storage/compute/__init__.py +0 -0
  73. ert/dark_storage/endpoints/compute/__init__.py +0 -0
  74. ert/dark_storage/endpoints/compute/misfits.py +0 -95
  75. ert/services/_base_service.py +0 -387
  76. ert/services/webviz_ert_service.py +0 -20
  77. ert/shared/storage/command.py +0 -38
  78. ert/shared/storage/extraction.py +0 -42
  79. {ert-19.0.0rc4.dist-info → ert-20.0.0b0.dist-info}/WHEEL +0 -0
  80. {ert-19.0.0rc4.dist-info → ert-20.0.0b0.dist-info}/entry_points.txt +0 -0
  81. {ert-19.0.0rc4.dist-info → ert-20.0.0b0.dist-info}/licenses/COPYING +0 -0
  82. {ert-19.0.0rc4.dist-info → ert-20.0.0b0.dist-info}/top_level.txt +0 -0
ert/gui/tools/plot/plot_api.py CHANGED
@@ -18,9 +18,12 @@ from pandas.api.types import is_numeric_dtype
18
18
  from pandas.errors import ParserError
19
19
  from resfo_utilities import history_key
20
20
 
21
- from ert.config import ParameterConfig, ResponseMetadata
22
- from ert.services import ErtServer
21
+ from ert.config import ParameterConfig
22
+ from ert.config.ensemble_config import ResponseConfig
23
+ from ert.config.known_response_types import KnownResponseTypes
24
+ from ert.services import create_ertserver_client
23
25
  from ert.storage.local_experiment import _parameters_adapter as parameter_config_adapter
26
+ from ert.storage.local_experiment import _responses_adapter as response_config_adapter
24
27
  from ert.storage.realization_storage_state import RealizationStorageState
25
28
 
26
29
  logger = logging.getLogger(__name__)
@@ -46,18 +49,18 @@ class PlotApiKeyDefinition(NamedTuple):
46
49
  metadata: dict[Any, Any]
47
50
  filter_on: dict[Any, Any] | None = None
48
51
  parameter: ParameterConfig | None = None
49
- response_metadata: ResponseMetadata | None = None
52
+ response: ResponseConfig | None = None
50
53
 
51
54
 
52
55
  class PlotApi:
53
56
  def __init__(self, ens_path: Path) -> None:
54
- self.ens_path = ens_path
57
+ self.ens_path: Path = ens_path
55
58
  self._all_ensembles: list[EnsembleObject] | None = None
56
59
  self._timeout = 120
57
60
 
58
61
  @property
59
62
  def api_version(self) -> str:
60
- with ErtServer.session(project=self.ens_path) as client:
63
+ with create_ertserver_client(self.ens_path) as client:
61
64
  try:
62
65
  response = client.get("/version", timeout=self._timeout)
63
66
  self._check_response(response)
@@ -83,7 +86,7 @@ class PlotApi:
83
86
  return self._all_ensembles
84
87
 
85
88
  self._all_ensembles = []
86
- with ErtServer.session(project=self.ens_path) as client:
89
+ with create_ertserver_client(self.ens_path) as client:
87
90
  try:
88
91
  response = client.get("/experiments", timeout=self._timeout)
89
92
  self._check_response(response)
@@ -139,7 +142,7 @@ class PlotApi:
139
142
  all_keys: dict[str, PlotApiKeyDefinition] = {}
140
143
  all_params = {}
141
144
 
142
- with ErtServer.session(project=self.ens_path) as client:
145
+ with create_ertserver_client(self.ens_path) as client:
143
146
  response = client.get("/experiments", timeout=self._timeout)
144
147
  self._check_response(response)
145
148
 
@@ -166,7 +169,7 @@ class PlotApi:
166
169
  def responses_api_key_defs(self) -> list[PlotApiKeyDefinition]:
167
170
  key_defs: dict[str, PlotApiKeyDefinition] = {}
168
171
 
169
- with ErtServer.session(project=self.ens_path) as client:
172
+ with create_ertserver_client(self.ens_path) as client:
170
173
  response = client.get("/experiments", timeout=self._timeout)
171
174
  self._check_response(response)
172
175
 
@@ -176,22 +179,26 @@ class PlotApi:
176
179
  key_defs[plot_key_def.key] = plot_key_def
177
180
 
178
181
  for experiment in response.json():
179
- for response_type, response_metadatas in experiment[
180
- "responses"
181
- ].items():
182
- for metadata in response_metadatas:
183
- key = metadata["response_key"]
182
+ for response_type, metadata in experiment["responses"].items():
183
+ response_config: KnownResponseTypes = (
184
+ response_config_adapter.validate_python(metadata)
185
+ )
186
+ keys = response_config.keys
187
+ for key in keys:
184
188
  has_obs = (
185
189
  response_type in experiment["observations"]
186
190
  and key in experiment["observations"][response_type]
187
191
  )
188
- if metadata["filter_on"]:
192
+ if response_config.filter_on is not None:
189
193
  # Only assume one filter_on, this code is to be
190
194
  # considered a bit "temp".
191
195
  # In general, we could create a dropdown per
192
196
  # filter_on on the frontend side
193
- for filter_key, values in metadata["filter_on"].items():
197
+
198
+ filter_for_key = response_config.filter_on.get(key, {})
199
+ for filter_key, values in filter_for_key.items():
194
200
  for v in values:
201
+ filter_on = {filter_key: v}
195
202
  subkey = f"{key}@{v}"
196
203
  update_keydef(
197
204
  PlotApiKeyDefinition(
@@ -202,10 +209,8 @@ class PlotApi:
202
209
  metadata={
203
210
  "data_origin": response_type,
204
211
  },
205
- filter_on={filter_key: v},
206
- response_metadata=ResponseMetadata(
207
- **metadata
208
- ),
212
+ filter_on=filter_on,
213
+ response=response_config,
209
214
  )
210
215
  )
211
216
  else:
@@ -216,7 +221,7 @@ class PlotApi:
216
221
  observations=has_obs,
217
222
  dimensionality=2,
218
223
  metadata={"data_origin": response_type},
219
- response_metadata=ResponseMetadata(**metadata),
224
+ response=response_config,
220
225
  )
221
226
  )
222
227
 
@@ -228,7 +233,9 @@ class PlotApi:
228
233
  response_key: str,
229
234
  filter_on: dict[str, Any] | None = None,
230
235
  ) -> pd.DataFrame:
231
- with ErtServer.session(project=self.ens_path) as client:
236
+ if "@" in response_key:
237
+ response_key = response_key.split("@", maxsplit=1)[0]
238
+ with create_ertserver_client(self.ens_path) as client:
232
239
  response = client.get(
233
240
  f"/ensembles/{ensemble_id}/responses/{PlotApi.escape(response_key)}",
234
241
  headers={"accept": "application/x-parquet"},
@@ -256,7 +263,7 @@ class PlotApi:
256
263
  return df
257
264
 
258
265
  def data_for_parameter(self, ensemble_id: str, parameter_key: str) -> pd.DataFrame:
259
- with ErtServer.session(project=self.ens_path) as client:
266
+ with create_ertserver_client(self.ens_path) as client:
260
267
  parameter = client.get(
261
268
  f"/ensembles/{ensemble_id}/parameters/{PlotApi.escape(parameter_key)}",
262
269
  headers={"accept": "application/x-parquet"},
@@ -294,11 +301,12 @@ class PlotApi:
294
301
  )
295
302
  if not key_def:
296
303
  raise httpx.RequestError(f"Response key {key_def} not found")
297
-
298
- assert key_def.response_metadata is not None
299
- actual_response_key = key_def.response_metadata.response_key
304
+ assert key_def.response is not None
305
+ actual_response_key = key
306
+ if "@" in actual_response_key:
307
+ actual_response_key = key.split("@", maxsplit=1)[0]
300
308
  filter_on = key_def.filter_on
301
- with ErtServer.session(project=self.ens_path) as client:
309
+ with create_ertserver_client(self.ens_path) as client:
302
310
  response = client.get(
303
311
  f"/ensembles/{ensemble.id}/responses/{PlotApi.escape(actual_response_key)}/observations",
304
312
  timeout=self._timeout,
@@ -386,7 +394,7 @@ class PlotApi:
386
394
  if not ensemble:
387
395
  return np.array([])
388
396
 
389
- with ErtServer.session(project=self.ens_path) as client:
397
+ with create_ertserver_client(self.ens_path) as client:
390
398
  response = client.get(
391
399
  f"/ensembles/{ensemble.id}/parameters/{PlotApi.escape(key)}/std_dev",
392
400
  params={"z": z},
ert/gui/tools/plot/plot_widget.py CHANGED
@@ -153,6 +153,7 @@ class PlotWidget(QWidget):
153
153
  # only for histogram plot see _sync_log_checkbox
154
154
  self._log_checkbox.setVisible(False)
155
155
  self._log_checkbox.setToolTip("Toggle data domain to log scale and back")
156
+ self._log_checkbox.clicked.connect(self.logLogScaleButtonUsage)
156
157
 
157
158
  log_checkbox_row = QHBoxLayout()
158
159
  log_checkbox_row.addWidget(self._log_checkbox)
@@ -193,6 +194,10 @@ class PlotWidget(QWidget):
193
194
  def name(self) -> str:
194
195
  return self._name
195
196
 
197
+ def logLogScaleButtonUsage(self) -> None:
198
+ logger.info(f"Plotwidget utility used: 'Log scale button' in tab '{self.name}'")
199
+ self._log_checkbox.clicked.disconnect() # Log only once
200
+
196
201
  def updatePlot(
197
202
  self,
198
203
  plot_context: "PlotContext",
ert/gui/tools/plot/plot_window.py CHANGED
@@ -25,7 +25,7 @@ from PyQt6.QtWidgets import (
25
25
  from ert.config.field import Field
26
26
  from ert.dark_storage.common import get_storage_api_version
27
27
  from ert.gui.ertwidgets import CopyButton, showWaitCursorWhileWaiting
28
- from ert.services._base_service import ServerBootFail
28
+ from ert.services import ServerBootFail
29
29
  from ert.utils import log_duration
30
30
 
31
31
  from .customize import PlotCustomizer
@@ -267,10 +267,10 @@ class PlotWindow(QMainWindow):
267
267
  ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame] = {}
268
268
  for ensemble in selected_ensembles:
269
269
  try:
270
- if key_def.response_metadata is not None:
270
+ if key_def.response is not None:
271
271
  ensemble_to_data_map[ensemble] = self._api.data_for_response(
272
272
  ensemble_id=ensemble.id,
273
- response_key=key_def.response_metadata.response_key,
273
+ response_key=key,
274
274
  filter_on=key_def.filter_on,
275
275
  )
276
276
  elif (
@@ -348,10 +348,7 @@ class PlotWindow(QMainWindow):
348
348
  handle_exception(e)
349
349
  plot_context.history_data = None
350
350
 
351
- if (
352
- key_def.response_metadata is not None
353
- and key_def.response_metadata.response_type == "rft"
354
- ):
351
+ if key_def.response is not None and key_def.response.type == "rft":
355
352
  plot_context.setXLabel(key.split(":")[-1])
356
353
  plot_context.setYLabel("TVD")
357
354
  plot_context.depth_y_axis = True
ert/run_models/ensemble_experiment.py CHANGED
@@ -55,9 +55,7 @@ class EnsembleExperiment(InitialEnsembleRunModel, EnsembleExperimentConfig):
55
55
 
56
56
  experiment_storage = self._storage.create_experiment(
57
57
  parameters=cast(list[ParameterConfig], self.parameter_configuration),
58
- observations={k: v.to_polars() for k, v in self.observations.items()}
59
- if self.observations is not None
60
- else None,
58
+ observations=self.observation_dataframes(),
61
59
  responses=cast(list[ResponseConfig], self.response_configuration),
62
60
  name=self.experiment_name,
63
61
  templates=self.ert_templates,
ert/run_models/ensemble_smoother.py CHANGED
@@ -49,9 +49,7 @@ class EnsembleSmoother(InitialEnsembleRunModel, UpdateRunModel, EnsembleSmoother
49
49
 
50
50
  experiment_storage = self._storage.create_experiment(
51
51
  parameters=cast(list[ParameterConfig], self.parameter_configuration),
52
- observations={k: v.to_polars() for k, v in self.observations.items()}
53
- if self.observations is not None
54
- else None,
52
+ observations=self.observation_dataframes(),
55
53
  responses=cast(list[ResponseConfig], self.response_configuration),
56
54
  name=self.experiment_name,
57
55
  templates=self.ert_templates,
ert/run_models/everest_run_model.py CHANGED
@@ -831,19 +831,18 @@ class EverestRunModel(RunModel, EverestRunModelConfig):
831
831
 
832
832
  ensemble.save_everest_realization_info(realization_info)
833
833
 
834
- for sim_id in range(sim_to_control_vector.shape[0]):
835
- sim_controls = sim_to_control_vector[sim_id]
836
- offset = 0
837
- for control_config in self._everest_control_configs:
838
- n_param_keys = len(control_config.parameter_keys)
839
-
840
- # Save controls to ensemble
841
- ensemble.save_parameters_numpy(
842
- sim_controls[offset : (offset + n_param_keys)].reshape(-1, 1),
843
- control_config.name,
844
- np.array([sim_id]),
845
- )
846
- offset += n_param_keys
834
+ iens = sim_to_control_vector.shape[0]
835
+ offset = 0
836
+ for control_config in self._everest_control_configs:
837
+ n_param_keys = len(control_config.parameter_keys)
838
+ name = control_config.name
839
+ parameters = sim_to_control_vector[:, offset : offset + n_param_keys]
840
+ ensemble.save_parameters_numpy(
841
+ parameters.reshape(-1, n_param_keys),
842
+ name,
843
+ np.arange(iens),
844
+ )
845
+ offset += n_param_keys
847
846
 
848
847
  # Evaluate the batch:
849
848
  run_args = self._get_run_args(
ert/run_models/initial_ensemble_run_model.py CHANGED
@@ -3,15 +3,17 @@ from typing import Annotated, Any, Literal, Self
3
3
  import numpy as np
4
4
  import polars as pl
5
5
  from polars.datatypes import DataTypeClass
6
- from pydantic import BaseModel, Field, field_validator
6
+ from pydantic import BaseModel, Field
7
7
 
8
8
  from ert.config import (
9
9
  EverestControl,
10
10
  GenKwConfig,
11
11
  KnownResponseTypes,
12
+ Observation,
12
13
  SurfaceConfig,
13
14
  )
14
15
  from ert.config import Field as FieldConfig
16
+ from ert.config._create_observation_dataframes import create_observation_dataframes
15
17
  from ert.ensemble_evaluator.config import EvaluatorServerConfig
16
18
  from ert.run_arg import create_run_arguments
17
19
  from ert.run_models.run_model import RunModel, RunModelConfig
@@ -67,27 +69,7 @@ class InitialEnsembleRunModelConfig(RunModelConfig):
67
69
  ]
68
70
  ]
69
71
  ert_templates: list[tuple[str, str]]
70
- observations: dict[str, DictEncodedDataFrame] | None = None
71
-
72
- @field_validator("observations", mode="before")
73
- @classmethod
74
- def make_dict_encoded_observations(
75
- cls, v: dict[str, pl.DataFrame | DictEncodedDataFrame | dict[str, Any]] | None
76
- ) -> dict[str, DictEncodedDataFrame] | None:
77
- if v is None:
78
- return None
79
-
80
- encoded = {}
81
- for k, df in v.items():
82
- match df:
83
- case DictEncodedDataFrame():
84
- encoded[k] = df
85
- case pl.DataFrame():
86
- encoded[k] = DictEncodedDataFrame.from_polars(df)
87
- case dict():
88
- encoded[k] = DictEncodedDataFrame.model_validate(df)
89
-
90
- return encoded
72
+ observations: list[Observation] | None = None
91
73
 
92
74
 
93
75
  class InitialEnsembleRunModel(RunModel, InitialEnsembleRunModelConfig):
@@ -101,6 +83,7 @@ class InitialEnsembleRunModel(RunModel, InitialEnsembleRunModelConfig):
101
83
  np.where(self.active_realizations)[0],
102
84
  parameters=[param.name for param in self.parameter_configuration],
103
85
  random_seed=self.random_seed,
86
+ num_realizations=self.runpath_config.num_realizations,
104
87
  design_matrix_df=self.design_matrix.to_polars()
105
88
  if self.design_matrix is not None
106
89
  else None,
@@ -117,3 +100,17 @@ class InitialEnsembleRunModel(RunModel, InitialEnsembleRunModelConfig):
117
100
  evaluator_server_config,
118
101
  )
119
102
  return ensemble_storage
103
+
104
+ def observation_dataframes(self) -> dict[str, pl.DataFrame]:
105
+ if self.observations is None:
106
+ return {}
107
+
108
+ rft_config = next(
109
+ (r for r in self.response_configuration if r.type == "rft"),
110
+ None,
111
+ )
112
+
113
+ return create_observation_dataframes(
114
+ observations=self.observations,
115
+ rft_config=rft_config,
116
+ )
ert/run_models/model_factory.py CHANGED
@@ -153,7 +153,7 @@ def _setup_single_test_run(
153
153
  log_path=config.analysis_config.log_path,
154
154
  storage_path=config.ens_path,
155
155
  queue_config=config.queue_config.create_local_copy(),
156
- observations=config.observations,
156
+ observations=config.observation_declarations,
157
157
  )
158
158
 
159
159
  return SingleTestRun(
@@ -212,7 +212,7 @@ def _setup_ensemble_experiment(
212
212
  log_path=config.analysis_config.log_path,
213
213
  storage_path=config.ens_path,
214
214
  queue_config=config.queue_config,
215
- observations=config.observations,
215
+ observations=config.observation_declarations,
216
216
  )
217
217
 
218
218
  return EnsembleExperiment(
@@ -305,7 +305,7 @@ def _setup_manual_update(
305
305
  hooked_workflows=config.hooked_workflows,
306
306
  log_path=config.analysis_config.log_path,
307
307
  ert_templates=config.ert_templates,
308
- observations=config.observations,
308
+ observations=config.observation_declarations,
309
309
  )
310
310
  return ManualUpdate(**runmodel_config.model_dump(), status_queue=status_queue)
311
311
 
@@ -343,7 +343,7 @@ def _setup_manual_update_enif(
343
343
  substitutions=config.substitutions,
344
344
  hooked_workflows=config.hooked_workflows,
345
345
  log_path=config.analysis_config.log_path,
346
- observations=config.observations,
346
+ observations=config.observation_declarations,
347
347
  )
348
348
 
349
349
 
@@ -389,7 +389,7 @@ def _setup_ensemble_smoother(
389
389
  substitutions=config.substitutions,
390
390
  hooked_workflows=config.hooked_workflows,
391
391
  log_path=config.analysis_config.log_path,
392
- observations=config.observations,
392
+ observations=config.observation_declarations,
393
393
  )
394
394
  return EnsembleSmoother(**runmodel_config.model_dump(), status_queue=status_queue)
395
395
 
@@ -435,7 +435,7 @@ def _setup_ensemble_information_filter(
435
435
  substitutions=config.substitutions,
436
436
  hooked_workflows=config.hooked_workflows,
437
437
  log_path=config.analysis_config.log_path,
438
- observations=config.observations,
438
+ observations=config.observation_declarations,
439
439
  )
440
440
  return EnsembleInformationFilter(
441
441
  **runmodel_config.model_dump(), status_queue=status_queue
@@ -507,7 +507,7 @@ def _setup_multiple_data_assimilation(
507
507
  substitutions=config.substitutions,
508
508
  hooked_workflows=config.hooked_workflows,
509
509
  log_path=config.analysis_config.log_path,
510
- observations=config.observations,
510
+ observations=config.observation_declarations,
511
511
  )
512
512
  return MultipleDataAssimilation(
513
513
  **runmodel_config.model_dump(), status_queue=status_queue
ert/run_models/multiple_data_assimilation.py CHANGED
@@ -116,9 +116,7 @@ class MultipleDataAssimilation(
116
116
  sim_args = {"weights": self.weights}
117
117
  experiment_storage = self._storage.create_experiment(
118
118
  parameters=cast(list[ParameterConfig], self.parameter_configuration),
119
- observations={k: v.to_polars() for k, v in self.observations.items()}
120
- if self.observations is not None
121
- else None,
119
+ observations=self.observation_dataframes(),
122
120
  responses=cast(list[ResponseConfig], self.response_configuration),
123
121
  name=self.experiment_name,
124
122
  templates=self.ert_templates,
ert/sample_prior.py CHANGED
@@ -21,6 +21,7 @@ def sample_prior(
21
21
  ensemble: Ensemble,
22
22
  active_realizations: Iterable[int],
23
23
  random_seed: int,
24
+ num_realizations: int,
24
25
  parameters: list[str] | None = None,
25
26
  design_matrix_df: pl.DataFrame | None = None,
26
27
  ) -> None:
@@ -65,21 +66,18 @@ def sample_prior(
65
66
  f"Sampling parameter {config_node.name} "
66
67
  f"for realizations {active_realizations}"
67
68
  )
68
- datasets = [
69
- Ensemble.sample_parameter(
70
- config_node,
71
- realization_nr,
72
- random_seed=random_seed,
73
- )
74
- for realization_nr in active_realizations
75
- ]
76
- if datasets:
77
- dataset = pl.concat(datasets, how="vertical")
69
+ dataset = Ensemble.sample_parameter(
70
+ config_node,
71
+ list(active_realizations),
72
+ random_seed=random_seed,
73
+ num_realizations=num_realizations,
74
+ )
75
+ if not (dataset is None or dataset.is_empty()):
76
+ if complete_dataset is None:
77
+ complete_dataset = dataset
78
+ elif dataset is not None:
79
+ complete_dataset = complete_dataset.join(dataset, on="realization")
78
80
 
79
- if complete_dataset is None:
80
- complete_dataset = dataset
81
- elif dataset is not None:
82
- complete_dataset = complete_dataset.join(dataset, on="realization")
83
81
  else:
84
82
  for realization_nr in active_realizations:
85
83
  ds = config_node.read_from_runpath(Path(), realization_nr, 0)
ert/services/__init__.py CHANGED
@@ -1,4 +1,8 @@
1
- from .ert_server import ErtServer
2
- from .webviz_ert_service import WebvizErt
1
+ from .ert_server import (
2
+ ErtServer,
3
+ ErtServerExit,
4
+ ServerBootFail,
5
+ create_ertserver_client,
6
+ )
3
7
 
4
- __all__ = ["ErtServer", "WebvizErt"]
8
+ __all__ = ["ErtServer", "ErtServerExit", "ServerBootFail", "create_ertserver_client"]
ert/services/_storage_main.py CHANGED
@@ -13,6 +13,7 @@ import sys
13
13
  import threading
14
14
  import time
15
15
  import warnings
16
+ from argparse import ArgumentParser
16
17
  from base64 import b64encode
17
18
  from pathlib import Path
18
19
  from typing import Any
@@ -29,12 +30,11 @@ from uvicorn.supervisors import ChangeReload
29
30
 
30
31
  from ert.logging import STORAGE_LOG_CONFIG
31
32
  from ert.plugins import setup_site_logging
32
- from ert.services._base_service import BaseServiceExit
33
+ from ert.services import ErtServerExit
33
34
  from ert.shared import __file__ as ert_shared_path
34
35
  from ert.shared import find_available_socket, get_machine_name
35
- from ert.shared.storage.command import add_parser_options
36
36
  from ert.trace import tracer
37
- from everest.util import makedirs_if_needed
37
+ from ert.utils import makedirs_if_needed
38
38
 
39
39
  DARK_STORAGE_APP = "ert.dark_storage.app:app"
40
40
 
@@ -82,7 +82,7 @@ def _get_host_list() -> list[str]:
82
82
 
83
83
 
84
84
  def _create_connection_info(
85
- sock: socket.socket, authtoken: str, cert: str | os.PathLike[str]
85
+ sock: socket.socket, authtoken: str, cert: str | os.PathLike[str] | Path
86
86
  ) -> dict[str, Any]:
87
87
  connection_info = {
88
88
  "urls": [
@@ -91,7 +91,7 @@ def _create_connection_info(
91
91
  "authtoken": authtoken,
92
92
  "host": get_machine_name(),
93
93
  "port": sock.getsockname()[1],
94
- "cert": cert,
94
+ "cert": str(cert),
95
95
  "auth": authtoken,
96
96
  }
97
97
 
@@ -102,14 +102,17 @@ def _create_connection_info(
102
102
  return connection_info
103
103
 
104
104
 
105
- def _generate_certificate(cert_folder: str) -> tuple[str, str, bytes]:
105
+ def _generate_certificate(cert_folder: Path) -> tuple[Path, Path, bytes]:
106
106
  """Generate a private key and a certificate signed with it
107
107
 
108
108
  Both the certificate and the key are written to files in the folder given
109
109
  by `get_certificate_dir(config)`. The key is encrypted before being
110
110
  stored.
111
- Returns the path to the certificate file, the path to the key file, and
112
- the password used for encrypting the key
111
+
112
+ Returns a 3-tuple with
113
+ * Certificate file path
114
+ * Key file path
115
+ * Password used for encrypting the key
113
116
  """
114
117
  # Generate private key
115
118
  key = rsa.generate_private_key(
@@ -150,11 +153,11 @@ def _generate_certificate(cert_folder: str) -> tuple[str, str, bytes]:
150
153
 
151
154
  # Write certificate and key to disk
152
155
  makedirs_if_needed(cert_folder)
153
- cert_path = os.path.join(cert_folder, dns_name + ".crt")
154
- Path(cert_path).write_bytes(cert.public_bytes(serialization.Encoding.PEM))
155
- key_path = os.path.join(cert_folder, dns_name + ".key")
156
+ cert_path = cert_folder / f"{dns_name}.crt"
157
+ cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
158
+ key_path = cert_folder / f"{dns_name}.key"
156
159
  pw = bytes(os.urandom(28))
157
- Path(key_path).write_bytes(
160
+ key_path.write_bytes(
158
161
  key.private_bytes(
159
162
  encoding=serialization.Encoding.PEM,
160
163
  format=serialization.PrivateFormat.TraditionalOpenSSL,
@@ -184,16 +187,16 @@ def run_server(
184
187
 
185
188
  config_args: dict[str, Any] = {}
186
189
  if args.debug or debug:
187
- config_args.update(reload=True, reload_dirs=[os.path.dirname(ert_shared_path)])
190
+ config_args.update(reload=True, reload_dirs=[Path(ert_shared_path).parent])
188
191
  os.environ["ERT_STORAGE_DEBUG"] = "1"
189
192
 
190
- sock = find_available_socket(
193
+ sock: socket.socket = find_available_socket(
191
194
  host=get_machine_name(), port_range=range(51850, 51870 + 1)
192
195
  )
193
196
 
194
197
  # Appropriated from uvicorn.main:run
195
198
  os.environ["ERT_STORAGE_NO_TOKEN"] = "1"
196
- os.environ["ERT_STORAGE_ENS_PATH"] = os.path.abspath(args.project)
199
+ os.environ["ERT_STORAGE_ENS_PATH"] = str(args.project.absolute())
197
200
  config = (
198
201
  # uvicorn.Config() resets the logging config (overriding additional
199
202
  # handlers added to loggers like e.g. the ert_azurelogger handler
@@ -253,11 +256,11 @@ def _join_terminate_thread(terminate_on_parent_death_thread: threading.Thread) -
253
256
  """Join the terminate thread, handling BaseServiceExit (which is used by Everest)"""
254
257
  try:
255
258
  terminate_on_parent_death_thread.join()
256
- except BaseServiceExit:
259
+ except ErtServerExit:
257
260
  logger = logging.getLogger("ert.shared.storage.info")
258
261
  logger.info(
259
262
  "Got BaseServiceExit while joining terminate thread, "
260
- "as expected from _base_service.py"
263
+ "as expected from ert_server.py"
261
264
  )
262
265
 
263
266
 
@@ -265,9 +268,7 @@ def main() -> None:
265
268
  args = parse_args()
266
269
  authentication = _generate_authentication()
267
270
  os.environ["ERT_STORAGE_TOKEN"] = authentication
268
- cert_path, key_path, key_pw = _generate_certificate(
269
- os.path.join(args.project, "cert")
270
- )
271
+ cert_path, key_path, key_pw = _generate_certificate(args.project / "cert")
271
272
  config_args: dict[str, Any] = {
272
273
  "ssl_keyfile": key_path,
273
274
  "ssl_certfile": cert_path,
@@ -283,7 +284,7 @@ def main() -> None:
283
284
  warnings.filterwarnings("ignore", category=DeprecationWarning)
284
285
 
285
286
  if args.debug:
286
- config_args.update(reload=True, reload_dirs=[os.path.dirname(ert_shared_path)])
287
+ config_args.update(reload=True, reload_dirs=[Path(ert_shared_path).parent])
287
288
 
288
289
  # Need to run uvicorn.Config before entering the ErtPluginContext because
289
290
  # uvicorn.Config overrides the configuration of existing loggers, thus removing
@@ -310,12 +311,48 @@ def main() -> None:
310
311
  logger.info("Starting dark storage")
311
312
  logger.info(f"Started dark storage with parent {args.parent_pid}")
312
313
  run_server(args, debug=False, uvicorn_config=uvicorn_config)
313
- except (SystemExit, BaseServiceExit):
314
+ except (SystemExit, ErtServerExit):
314
315
  logger.info("Stopping dark storage")
315
316
  finally:
316
317
  stopped.set()
317
318
  _join_terminate_thread(terminate_on_parent_death_thread)
318
319
 
319
320
 
321
+ def add_parser_options(ap: ArgumentParser) -> None:
322
+ ap.add_argument(
323
+ "config",
324
+ type=str,
325
+ help=("ERT config file to start the server from "),
326
+ nargs="?", # optional
327
+ )
328
+ ap.add_argument(
329
+ "--project",
330
+ "-p",
331
+ type=Path,
332
+ help="Path to directory in which to create storage_server.json",
333
+ default=Path.cwd(),
334
+ )
335
+ ap.add_argument(
336
+ "--traceparent",
337
+ type=str,
338
+ help="Trace parent id to be used by the storage root span",
339
+ default=None,
340
+ )
341
+ ap.add_argument(
342
+ "--parent_pid",
343
+ type=int,
344
+ help="The parent process id",
345
+ default=os.getppid(),
346
+ )
347
+ ap.add_argument(
348
+ "--host", type=str, default=os.environ.get("ERT_STORAGE_HOST", "127.0.0.1")
349
+ )
350
+ ap.add_argument("--logging-config", type=str, default=None)
351
+ ap.add_argument(
352
+ "--verbose", action="store_true", help="Show verbose output.", default=False
353
+ )
354
+ ap.add_argument("--debug", action="store_true", default=False)
355
+
356
+
320
357
  if __name__ == "__main__":
321
358
  main()