ert 20.0.0b0__py3-none-any.whl → 20.0.0b1__py3-none-any.whl

This diff shows the changes between package versions that have been publicly released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -2,16 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  import functools
4
4
  import logging
5
- from typing import cast
6
5
 
7
6
  import numpy as np
8
7
  from pydantic import PrivateAttr
9
8
 
10
9
  from ert.config import (
11
- ParameterConfig,
12
10
  PostExperimentFixtures,
13
11
  PreExperimentFixtures,
14
- ResponseConfig,
15
12
  )
16
13
  from ert.ensemble_evaluator import EvaluatorServerConfig
17
14
  from ert.run_models.initial_ensemble_run_model import (
@@ -48,11 +45,8 @@ class EnsembleSmoother(InitialEnsembleRunModel, UpdateRunModel, EnsembleSmoother
48
45
  self.run_workflows(fixtures=PreExperimentFixtures(random_seed=self.random_seed))
49
46
 
50
47
  experiment_storage = self._storage.create_experiment(
51
- parameters=cast(list[ParameterConfig], self.parameter_configuration),
52
- observations=self.observation_dataframes(),
53
- responses=cast(list[ResponseConfig], self.response_configuration),
48
+ experiment_config=self.model_dump(mode="json"),
54
49
  name=self.experiment_name,
55
- templates=self.ert_templates,
56
50
  )
57
51
 
58
52
  prior = self._storage.create_ensemble(
@@ -17,11 +17,11 @@ from enum import IntEnum, auto
17
17
  from functools import cached_property
18
18
  from pathlib import Path
19
19
  from types import TracebackType
20
- from typing import TYPE_CHECKING, Any, Protocol, cast
20
+ from typing import TYPE_CHECKING, Annotated, Any, Protocol
21
21
 
22
22
  import numpy as np
23
23
  from numpy.typing import NDArray
24
- from pydantic import PrivateAttr, ValidationError
24
+ from pydantic import Field, PrivateAttr, TypeAdapter, ValidationError
25
25
  from ropt.enums import ExitCode as RoptExitCode
26
26
  from ropt.evaluator import EvaluatorContext, EvaluatorResult
27
27
  from ropt.results import FunctionResults, Results
@@ -36,7 +36,6 @@ from ert.config import (
36
36
  GenDataConfig,
37
37
  HookRuntime,
38
38
  KnownQueueOptionsAdapter,
39
- ParameterConfig,
40
39
  QueueConfig,
41
40
  ResponseConfig,
42
41
  SummaryConfig,
@@ -222,12 +221,24 @@ def _get_workflows(
222
221
  return res_hooks, res_workflows
223
222
 
224
223
 
224
+ EverestResponseTypes = (
225
+ EverestObjectivesConfig | EverestConstraintsConfig | SummaryConfig | GenDataConfig
226
+ )
227
+
228
+ EverestResponseTypesAdapter = TypeAdapter( # type: ignore
229
+ Annotated[
230
+ EverestResponseTypes,
231
+ Field(discriminator="type"),
232
+ ]
233
+ )
234
+
235
+
225
236
  class EverestRunModelConfig(RunModelConfig):
226
237
  optimization_output_dir: str
227
238
  simulation_dir: str
228
239
 
229
- parameter_configuration: list[ParameterConfig]
230
- response_configuration: list[ResponseConfig]
240
+ parameter_configuration: list[EverestControl]
241
+ response_configuration: list[EverestResponseTypes]
231
242
 
232
243
  input_constraints: list[InputConstraintConfig]
233
244
  optimization: OptimizationConfig
@@ -542,7 +553,7 @@ class EverestRunModel(RunModel, EverestRunModelConfig):
542
553
 
543
554
  # There will and must always be one EverestControl config for an
544
555
  # Everest optimization.
545
- return cast(list[EverestControl], controls)
556
+ return controls
546
557
 
547
558
  @cached_property
548
559
  def _transforms(self) -> EverestOptModelTransforms:
@@ -676,9 +687,7 @@ class EverestRunModel(RunModel, EverestRunModelConfig):
676
687
  self._eval_server_cfg = evaluator_server_config
677
688
 
678
689
  self._experiment = self._experiment or self._storage.create_experiment(
679
- name=self.experiment_name,
680
- parameters=self.parameter_configuration,
681
- responses=self.response_configuration,
690
+ name=self.experiment_name, experiment_config=self.model_dump(mode="json")
682
691
  )
683
692
 
684
693
  self._experiment.status = ExperimentStatus(
@@ -764,7 +773,7 @@ class EverestRunModel(RunModel, EverestRunModelConfig):
764
773
 
765
774
  def _create_optimizer(self) -> tuple[BasicOptimizer, list[float]]:
766
775
  enopt_config, initial_guesses = everest2ropt(
767
- cast(list[EverestControl], self.parameter_configuration),
776
+ self.parameter_configuration,
768
777
  self.objectives_config,
769
778
  self.input_constraints,
770
779
  self.output_constraints_config,
@@ -46,13 +46,19 @@ class ManualUpdate(UpdateRunModel, ManualUpdateConfig):
46
46
  self.set_env_key("_ERT_EXPERIMENT_ID", str(prior_experiment.id))
47
47
  self.set_env_key("_ERT_ENSEMBLE_ID", str(self._prior.id))
48
48
 
49
+ experiment_config = self.model_dump(mode="json") | {
50
+ "parameter_configuration": prior_experiment.experiment_config[
51
+ "parameter_configuration"
52
+ ],
53
+ "response_configuration": prior_experiment.experiment_config[
54
+ "response_configuration"
55
+ ],
56
+ "observations": prior_experiment.experiment_config["observations"],
57
+ }
58
+
49
59
  target_experiment = self._storage.create_experiment(
50
- parameters=list(prior_experiment.parameter_configuration.values()),
51
- responses=list(prior_experiment.response_configuration.values()),
52
- observations=prior_experiment.observations,
53
- simulation_arguments=prior_experiment.metadata,
60
+ experiment_config=experiment_config,
54
61
  name=f"Manual update of {self._prior.name}",
55
- templates=self.ert_templates,
56
62
  )
57
63
  self.update(
58
64
  self._prior,
@@ -2,16 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  import functools
4
4
  import logging
5
- from typing import Any, ClassVar, cast
5
+ from typing import Any, ClassVar
6
6
  from uuid import UUID
7
7
 
8
8
  from pydantic import PrivateAttr
9
9
 
10
10
  from ert.config import (
11
- ParameterConfig,
12
11
  PostExperimentFixtures,
13
12
  PreExperimentFixtures,
14
- ResponseConfig,
15
13
  )
16
14
  from ert.ensemble_evaluator import EvaluatorServerConfig
17
15
  from ert.run_models.initial_ensemble_run_model import (
@@ -97,12 +95,8 @@ class MultipleDataAssimilation(
97
95
  f"restart iteration = {prior.iteration + 1}"
98
96
  )
99
97
  target_experiment = self._storage.create_experiment(
100
- parameters=list(prior.experiment.parameter_configuration.values()),
101
- responses=list(prior.experiment.response_configuration.values()),
102
- observations=prior.experiment.observations,
103
- simulation_arguments=prior.experiment.metadata,
98
+ experiment_config=self.model_dump(mode="json"),
104
99
  name=f"Restart from {prior.name}",
105
- templates=self.ert_templates,
106
100
  )
107
101
 
108
102
  except (KeyError, ValueError) as err:
@@ -113,14 +107,9 @@ class MultipleDataAssimilation(
113
107
  self.run_workflows(
114
108
  fixtures=PreExperimentFixtures(random_seed=self.random_seed),
115
109
  )
116
- sim_args = {"weights": self.weights}
117
110
  experiment_storage = self._storage.create_experiment(
118
- parameters=cast(list[ParameterConfig], self.parameter_configuration),
119
- observations=self.observation_dataframes(),
120
- responses=cast(list[ResponseConfig], self.response_configuration),
111
+ experiment_config=self.model_dump(mode="json"),
121
112
  name=self.experiment_name,
122
- templates=self.ert_templates,
123
- simulation_arguments=sim_args,
124
113
  )
125
114
 
126
115
  prior = self._storage.create_ensemble(
ert/scheduler/job.py CHANGED
@@ -482,6 +482,7 @@ async def log_warnings_from_forward_model(
482
482
  real: Realization,
483
483
  job_submission_time: float,
484
484
  timeout_seconds: int = Job.DEFAULT_FILE_VERIFICATION_TIMEOUT,
485
+ max_logged_warnings: int = 200,
485
486
  ) -> int:
486
487
  """Parse all stdout and stderr files from running the forward model
487
488
  for anything that looks like a Warning, and log it.
@@ -513,11 +514,19 @@ async def log_warnings_from_forward_model(
513
514
  )
514
515
 
515
516
  async def log_warnings_from_file(
516
- file: Path, iens: int, step: ForwardModelStep, step_idx: int, filetype: str
517
- ) -> None:
517
+ file: Path,
518
+ iens: int,
519
+ step: ForwardModelStep,
520
+ step_idx: int,
521
+ filetype: str,
522
+ max_warnings_to_log: int,
523
+ ) -> int:
524
+ """Returns how many times a warning was logged"""
518
525
  captured: list[str] = []
519
526
  file_text = await anyio.Path(file).read_text(encoding="utf-8")
520
527
  for line in file_text.splitlines():
528
+ if len(captured) >= max_warnings_to_log:
529
+ break
521
530
  if line_contains_warning(line):
522
531
  captured.append(line[:max_length])
523
532
 
@@ -528,6 +537,7 @@ async def log_warnings_from_forward_model(
528
537
  )
529
538
  warnings.warn(warning_msg, PostSimulationWarning, stacklevel=2)
530
539
  logger.warning(warning_msg)
540
+ return len(captured)
531
541
 
532
542
  async def wait_for_file(file_path: Path, _timeout: int) -> int:
533
543
  if _timeout <= 0:
@@ -546,6 +556,7 @@ async def log_warnings_from_forward_model(
546
556
  break
547
557
  return remaining_timeout
548
558
 
559
+ log_count = 0
549
560
  with suppress(KeyError):
550
561
  runpath = Path(real.run_arg.runpath)
551
562
  for step_idx, step in enumerate(real.fm_steps):
@@ -560,9 +571,18 @@ async def log_warnings_from_forward_model(
560
571
  if timeout_seconds <= 0:
561
572
  break
562
573
 
563
- await log_warnings_from_file(
564
- std_path, real.iens, step, step_idx, file_type
574
+ log_count += await log_warnings_from_file(
575
+ std_path,
576
+ real.iens,
577
+ step,
578
+ step_idx,
579
+ file_type,
580
+ max_logged_warnings - log_count,
565
581
  )
566
582
  if timeout_seconds <= 0:
567
583
  break
584
+ if log_count >= max_logged_warnings:
585
+ logger.warning(
586
+ "Reached maximum number of forward model step warnings to extract"
587
+ )
568
588
  return timeout_seconds
ert/shared/net_utils.py CHANGED
@@ -1,10 +1,8 @@
1
- import ipaddress
2
1
  import logging
3
2
  import random
4
3
  import socket
5
4
  from functools import lru_cache
6
5
 
7
- import psutil
8
6
  from dns import exception, resolver, reversename
9
7
 
10
8
 
@@ -52,7 +50,6 @@ def get_machine_name() -> str:
52
50
  def find_available_socket(
53
51
  host: str | None = None,
54
52
  port_range: range = range(51820, 51840 + 1),
55
- prioritize_private_ip_address: bool = False,
56
53
  ) -> socket.socket:
57
54
  """
58
55
  The default and recommended approach here is to return a bound socket to the
@@ -74,9 +71,7 @@ def find_available_socket(
74
71
 
75
72
  See e.g. implementation and comments in EvaluatorServerConfig
76
73
  """
77
- current_host = (
78
- host if host is not None else get_ip_address(prioritize_private_ip_address)
79
- )
74
+ current_host = host if host is not None else get_ip_address()
80
75
 
81
76
  if port_range.start == port_range.stop:
82
77
  ports = list(range(port_range.start, port_range.stop + 1))
@@ -140,40 +135,20 @@ def get_family(host: str) -> socket.AddressFamily:
140
135
  return socket.AF_INET6
141
136
 
142
137
 
143
- def get_ip_address(prioritize_private: bool = False) -> str:
144
- """
145
- Get the first (private or public) IPv4 address of the current machine on the LAN.
146
- Default behaviour returns the first public IP if found, then private, then loopback.
147
-
148
- Parameters:
149
- prioritize_private (bool): If True, private IP addresses are prioritized
150
-
151
- Returns:
152
- str: The selected IP address as a string.
153
- """
154
- loopback = ""
155
- public = ""
156
- private = ""
157
- interfaces = psutil.net_if_addrs()
158
- for addresses in interfaces.values():
159
- for address in addresses:
160
- if address.family.name == "AF_INET":
161
- ip = address.address
162
- if ipaddress.ip_address(ip).is_loopback and not loopback:
163
- loopback = ip
164
- elif ipaddress.ip_address(ip).is_private and not private:
165
- private = ip
166
- elif not public:
167
- public = ip
168
-
169
- # Select first non-empty value, based on prioritization
170
- if prioritize_private:
171
- selected_ip = private or public or loopback
172
- else:
173
- selected_ip = public or private or loopback
174
-
175
- if selected_ip:
176
- return selected_ip
177
- else:
178
- logger.warning("Cannot determine ip-address. Falling back to 127.0.0.1")
179
- return "127.0.0.1"
138
+ # See https://stackoverflow.com/a/28950776
139
+ def get_ip_address() -> str:
140
+ try:
141
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
142
+ try:
143
+ s.settimeout(0)
144
+ # try pinging a reserved, internal address in order
145
+ # to determine IP representing the default route
146
+ s.connect(("10.255.255.255", 1))
147
+ address = s.getsockname()[0]
148
+ finally:
149
+ s.close()
150
+ except BaseException:
151
+ logger.warning("Cannot determine ip-address. Falling back to localhost.")
152
+ address = "127.0.0.1"
153
+ logger.debug(f"ip-address: {address}")
154
+ return address
ert/shared/version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '20.0.0b0'
32
- __version_tuple__ = version_tuple = (20, 0, 0, 'b0')
31
+ __version__ = version = '20.0.0b1'
32
+ __version_tuple__ = version_tuple = (20, 0, 0, 'b1')
33
33
 
34
- __commit_id__ = commit_id = 'g23469e640'
34
+ __commit_id__ = commit_id = 'g4156859a8'
@@ -1077,7 +1077,9 @@ class LocalEnsemble(BaseMode):
1077
1077
  for col, observed_values in observed_cols.items():
1078
1078
  if col != "time":
1079
1079
  responses = responses.filter(
1080
- pl.col(col).is_in(observed_values.implode())
1080
+ pl.col(col).is_in(
1081
+ observed_values.implode(), nulls_equal=True
1082
+ )
1081
1083
  )
1082
1084
 
1083
1085
  pivoted = responses.collect(engine="streaming").pivot(
@@ -1120,6 +1122,7 @@ class LocalEnsemble(BaseMode):
1120
1122
  pivoted,
1121
1123
  how="left",
1122
1124
  on=["response_key", *response_cls.primary_key],
1125
+ nulls_equal=True,
1123
1126
  )
1124
1127
 
1125
1128
  # Do not drop primary keys which
@@ -24,7 +24,10 @@ from ert.config import (
24
24
  SurfaceConfig,
25
25
  )
26
26
  from ert.config import Field as FieldConfig
27
- from ert.config.parsing.context_values import ContextBoolEncoder
27
+ from ert.config._create_observation_dataframes import (
28
+ create_observation_dataframes,
29
+ )
30
+ from ert.config._observations import Observation
28
31
 
29
32
  from .mode import BaseMode, Mode, require_write
30
33
 
@@ -54,6 +57,7 @@ class _Index(BaseModel):
54
57
  # from a different experiment. For example, a manual update
55
58
  # is a separate experiment from the one that created the prior.
56
59
  ensembles: list[UUID]
60
+ experiment: dict[str, Any] = {}
57
61
  status: ExperimentStatus | None = Field(default=None)
58
62
 
59
63
 
@@ -80,9 +84,6 @@ class LocalExperiment(BaseMode):
80
84
  arguments. Provides methods to create and access associated ensembles.
81
85
  """
82
86
 
83
- _parameter_file = Path("parameter.json")
84
- _responses_file = Path("responses.json")
85
- _metadata_file = Path("metadata.json")
86
87
  _templates_file = Path("templates.json")
87
88
  _index_file = Path("index.json")
88
89
 
@@ -118,63 +119,21 @@ class LocalExperiment(BaseMode):
118
119
  storage: LocalStorage,
119
120
  uuid: UUID,
120
121
  path: Path,
121
- *,
122
- parameters: list[ParameterConfig] | None = None,
123
- responses: list[ResponseConfig] | None = None,
124
- observations: dict[str, pl.DataFrame] | None = None,
125
- simulation_arguments: dict[Any, Any] | None = None,
122
+ experiment_config: dict[str, Any],
126
123
  name: str | None = None,
127
- templates: list[tuple[str, str]] | None = None,
128
124
  ) -> LocalExperiment:
129
- """
130
- Create a new LocalExperiment and store its configuration data.
131
-
132
- Parameters
133
- ----------
134
- storage : LocalStorage
135
- Storage instance for experiment creation.
136
- uuid : UUID
137
- Unique identifier for the new experiment.
138
- path : Path
139
- File system path for storing experiment data.
140
- parameters : list of ParameterConfig, optional
141
- List of parameter configurations.
142
- responses : list of ResponseConfig, optional
143
- List of response configurations.
144
- observations : dict of str to encoded observation datasets, optional
145
- Observations dictionary.
146
- simulation_arguments : SimulationArguments, optional
147
- Simulation arguments for the experiment.
148
- name : str, optional
149
- Experiment name. Defaults to current date if None.
150
- templates : list of tuple[str, str], optional
151
- Run templates for the experiment. Defaults to None.
152
-
153
- Returns
154
- -------
155
- local_experiment : LocalExperiment
156
- Instance of the newly created experiment.
157
- """
158
125
  if name is None:
159
126
  name = datetime.today().isoformat()
160
127
 
161
128
  storage._write_transaction(
162
129
  path / cls._index_file,
163
- _Index(id=uuid, name=name, ensembles=[])
130
+ _Index(id=uuid, name=name, ensembles=[], experiment=experiment_config)
164
131
  .model_dump_json(indent=2, exclude_none=True)
165
132
  .encode("utf-8"),
166
133
  )
167
134
 
168
- parameter_data = {}
169
- for parameter in parameters or []:
170
- parameter.save_experiment_data(path)
171
- parameter_data.update({parameter.name: parameter.model_dump(mode="json")})
172
- storage._write_transaction(
173
- path / cls._parameter_file,
174
- json.dumps(parameter_data, indent=2).encode("utf-8"),
175
- )
176
-
177
- if templates:
135
+ templates = experiment_config.get("ert_templates")
136
+ if templates is not None:
178
137
  templates_path = path / "templates"
179
138
  templates_path.mkdir(parents=True, exist_ok=True)
180
139
  templates_abs: list[tuple[str, str]] = []
@@ -191,27 +150,29 @@ class LocalExperiment(BaseMode):
191
150
  json.dumps(templates_abs).encode("utf-8"),
192
151
  )
193
152
 
194
- response_data = {}
195
- for response in responses or []:
196
- response_data.update({response.type: response.model_dump(mode="json")})
197
- storage._write_transaction(
198
- path / cls._responses_file,
199
- json.dumps(response_data, default=str, indent=2).encode("utf-8"),
200
- )
201
-
202
- if observations:
153
+ observation_declarations = experiment_config.get("observations")
154
+ if observation_declarations:
203
155
  output_path = path / "observations"
204
- output_path.mkdir()
205
- for response_type, dataset in observations.items():
206
- storage._to_parquet_transaction(
207
- output_path / f"{response_type}", dataset
208
- )
156
+ output_path.mkdir(parents=True, exist_ok=True)
209
157
 
210
- simulation_data = simulation_arguments or {}
211
- storage._write_transaction(
212
- path / cls._metadata_file,
213
- json.dumps(simulation_data, cls=ContextBoolEncoder).encode("utf-8"),
214
- )
158
+ responses_list = experiment_config.get("response_configuration", [])
159
+ rft_config_json = next(
160
+ (r for r in responses_list if r.get("type") == "rft"), None
161
+ )
162
+ rft_config = (
163
+ _responses_adapter.validate_python(rft_config_json)
164
+ if rft_config_json is not None
165
+ else None
166
+ )
167
+
168
+ obs_adapter = TypeAdapter(Observation) # type: ignore
169
+ obs_objs: list[Observation] = []
170
+ for od in observation_declarations:
171
+ obs_objs.append(obs_adapter.validate_python(od))
172
+
173
+ datasets = create_observation_dataframes(obs_objs, rft_config)
174
+ for response_type, df in datasets.items():
175
+ storage._to_parquet_transaction(output_path / response_type, df)
215
176
 
216
177
  return cls(storage, path, Mode.WRITE)
217
178
 
@@ -285,16 +246,10 @@ class LocalExperiment(BaseMode):
285
246
  return ens
286
247
  raise KeyError(f"Ensemble with name '{name}' not found")
287
248
 
288
- @property
289
- def metadata(self) -> dict[str, Any]:
290
- path = self.mount_point / self._metadata_file
291
- if not path.exists():
292
- raise ValueError(f"{self._metadata_file!s} does not exist")
293
- return json.loads(path.read_text(encoding="utf-8"))
294
-
295
249
  @property
296
250
  def relative_weights(self) -> str:
297
- return self.metadata.get("weights", "")
251
+ assert self.experiment_config is not None
252
+ return self.experiment_config.get("weights", "")
298
253
 
299
254
  @property
300
255
  def name(self) -> str:
@@ -324,9 +279,8 @@ class LocalExperiment(BaseMode):
324
279
 
325
280
  @property
326
281
  def parameter_info(self) -> dict[str, Any]:
327
- return json.loads(
328
- (self.mount_point / self._parameter_file).read_text(encoding="utf-8")
329
- )
282
+ parameters_list = self.experiment_config.get("parameter_configuration", [])
283
+ return {parameter["name"]: parameter for parameter in parameters_list}
330
284
 
331
285
  @property
332
286
  def templates_configuration(self) -> list[tuple[str, str]]:
@@ -348,9 +302,8 @@ class LocalExperiment(BaseMode):
348
302
 
349
303
  @property
350
304
  def response_info(self) -> dict[str, Any]:
351
- return json.loads(
352
- (self.mount_point / self._responses_file).read_text(encoding="utf-8")
353
- )
305
+ responses_list = self.experiment_config.get("response_configuration", [])
306
+ return {response["type"]: response for response in responses_list}
354
307
 
355
308
  def get_surface(self, name: str) -> IrapSurface:
356
309
  """
@@ -420,10 +373,50 @@ class LocalExperiment(BaseMode):
420
373
 
421
374
  @cached_property
422
375
  def observations(self) -> dict[str, pl.DataFrame]:
423
- observations = sorted(self.mount_point.glob("observations/*"))
376
+ obs_dir = self.mount_point / "observations"
377
+
378
+ if obs_dir.exists():
379
+ datasets: dict[str, pl.DataFrame] = {}
380
+ for p in obs_dir.iterdir():
381
+ if not p.is_file():
382
+ continue
383
+ try:
384
+ df = pl.read_parquet(p)
385
+ except Exception:
386
+ continue
387
+ datasets[p.stem] = df
388
+ return datasets
389
+
390
+ serialized_observations = self.experiment_config.get("observations", None)
391
+ if not serialized_observations:
392
+ return {}
393
+
394
+ output_path = self.mount_point / "observations"
395
+ output_path.mkdir(parents=True, exist_ok=True)
396
+
397
+ rft_cfg = None
398
+ try:
399
+ responses_list = self.experiment_config.get("response_configuration", [])
400
+ for r in responses_list:
401
+ if r.get("type") == "rft":
402
+ rft_cfg = _responses_adapter.validate_python(r)
403
+ break
404
+ except Exception:
405
+ rft_cfg = None
406
+
407
+ obs_adapter = TypeAdapter(Observation) # type: ignore
408
+ obs_objs: list[Observation] = []
409
+ for od in serialized_observations:
410
+ obs_objs.append(obs_adapter.validate_python(od))
411
+
412
+ datasets = create_observation_dataframes(obs_objs, rft_cfg)
413
+ for response_type, df in datasets.items():
414
+ self._storage._to_parquet_transaction(output_path / response_type, df)
415
+
424
416
  return {
425
- observation.name: pl.read_parquet(f"{observation}")
426
- for observation in observations
417
+ p.stem: pl.read_parquet(p)
418
+ for p in (self.mount_point / "observations").iterdir()
419
+ if p.is_file()
427
420
  }
428
421
 
429
422
  @cached_property
@@ -489,18 +482,22 @@ class LocalExperiment(BaseMode):
489
482
  )
490
483
 
491
484
  config = responses_configuration[response_type]
485
+
492
486
  config.keys = sorted(response_keys)
493
487
  config.has_finalized_keys = True
488
+
489
+ response_index = next(
490
+ i
491
+ for i, c in enumerate(self.experiment_config["response_configuration"])
492
+ if c["type"] == response_type
493
+ )
494
+ self.experiment_config["response_configuration"][response_index] = (
495
+ config.model_dump(mode="json")
496
+ )
497
+
494
498
  self._storage._write_transaction(
495
- self._path / self._responses_file,
496
- json.dumps(
497
- {
498
- c.type: c.model_dump(mode="json")
499
- for c in responses_configuration.values()
500
- },
501
- default=str,
502
- indent=2,
503
- ).encode("utf-8"),
499
+ self._path / self._index_file,
500
+ self._index.model_dump_json(indent=2).encode("utf-8"),
504
501
  )
505
502
 
506
503
  if self.response_key_to_response_type is not None:
@@ -508,3 +505,7 @@ class LocalExperiment(BaseMode):
508
505
 
509
506
  if self.response_type_to_response_keys is not None:
510
507
  del self.response_type_to_response_keys
508
+
509
+ @property
510
+ def experiment_config(self) -> dict[str, Any]:
511
+ return self._index.experiment