pyetp 0.0.38__py3-none-any.whl → 0.0.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyetp/client.py CHANGED
@@ -4,52 +4,176 @@ import logging
4
4
  import sys
5
5
  import typing as T
6
6
  import uuid
7
+ import warnings
7
8
  from collections import defaultdict
8
9
  from types import TracebackType
9
- import time
10
+
10
11
  import numpy as np
11
12
  import websockets
12
- from etpproto.connection import (CommunicationProtocol, ConnectionType,
13
- ETPConnection)
13
+ from etpproto.connection import CommunicationProtocol, ConnectionType, ETPConnection
14
14
  from etpproto.messages import Message, MessageFlags
15
+ from etptypes import ETPModel
16
+ from etptypes.energistics.etp.v12.datatypes.any_array_type import AnyArrayType
17
+ from etptypes.energistics.etp.v12.datatypes.any_logical_array_type import (
18
+ AnyLogicalArrayType,
19
+ )
20
+ from etptypes.energistics.etp.v12.datatypes.array_of_string import ArrayOfString
21
+ from etptypes.energistics.etp.v12.datatypes.data_array_types.data_array_identifier import (
22
+ DataArrayIdentifier,
23
+ )
24
+ from etptypes.energistics.etp.v12.datatypes.data_array_types.data_array_metadata import (
25
+ DataArrayMetadata,
26
+ )
27
+ from etptypes.energistics.etp.v12.datatypes.data_array_types.get_data_subarrays_type import (
28
+ GetDataSubarraysType,
29
+ )
30
+ from etptypes.energistics.etp.v12.datatypes.data_array_types.put_data_arrays_type import (
31
+ PutDataArraysType,
32
+ )
33
+ from etptypes.energistics.etp.v12.datatypes.data_array_types.put_data_subarrays_type import (
34
+ PutDataSubarraysType,
35
+ )
36
+ from etptypes.energistics.etp.v12.datatypes.data_array_types.put_uninitialized_data_array_type import (
37
+ PutUninitializedDataArrayType,
38
+ )
39
+ from etptypes.energistics.etp.v12.datatypes.data_value import DataValue
40
+ from etptypes.energistics.etp.v12.datatypes.error_info import ErrorInfo
41
+ from etptypes.energistics.etp.v12.datatypes.object.context_info import ContextInfo
42
+ from etptypes.energistics.etp.v12.datatypes.object.context_scope_kind import (
43
+ ContextScopeKind,
44
+ )
45
+ from etptypes.energistics.etp.v12.datatypes.object.data_object import DataObject
46
+ from etptypes.energistics.etp.v12.datatypes.object.dataspace import Dataspace
47
+ from etptypes.energistics.etp.v12.datatypes.object.relationship_kind import (
48
+ RelationshipKind,
49
+ )
50
+ from etptypes.energistics.etp.v12.datatypes.object.resource import Resource
51
+ from etptypes.energistics.etp.v12.datatypes.supported_data_object import (
52
+ SupportedDataObject,
53
+ )
54
+ from etptypes.energistics.etp.v12.datatypes.supported_protocol import SupportedProtocol
55
+ from etptypes.energistics.etp.v12.datatypes.uuid import Uuid
56
+ from etptypes.energistics.etp.v12.datatypes.version import Version
57
+ from etptypes.energistics.etp.v12.protocol.core.authorize import Authorize
58
+ from etptypes.energistics.etp.v12.protocol.core.authorize_response import (
59
+ AuthorizeResponse,
60
+ )
61
+ from etptypes.energistics.etp.v12.protocol.core.close_session import CloseSession
62
+ from etptypes.energistics.etp.v12.protocol.core.open_session import OpenSession
63
+ from etptypes.energistics.etp.v12.protocol.core.protocol_exception import (
64
+ ProtocolException,
65
+ )
66
+ from etptypes.energistics.etp.v12.protocol.core.request_session import RequestSession
67
+ from etptypes.energistics.etp.v12.protocol.data_array.get_data_array_metadata import (
68
+ GetDataArrayMetadata,
69
+ )
70
+ from etptypes.energistics.etp.v12.protocol.data_array.get_data_array_metadata_response import (
71
+ GetDataArrayMetadataResponse,
72
+ )
73
+ from etptypes.energistics.etp.v12.protocol.data_array.get_data_arrays import (
74
+ GetDataArrays,
75
+ )
76
+ from etptypes.energistics.etp.v12.protocol.data_array.get_data_arrays_response import (
77
+ GetDataArraysResponse,
78
+ )
79
+ from etptypes.energistics.etp.v12.protocol.data_array.get_data_subarrays import (
80
+ GetDataSubarrays,
81
+ )
82
+ from etptypes.energistics.etp.v12.protocol.data_array.get_data_subarrays_response import (
83
+ GetDataSubarraysResponse,
84
+ )
85
+ from etptypes.energistics.etp.v12.protocol.data_array.put_data_arrays import (
86
+ PutDataArrays,
87
+ )
88
+ from etptypes.energistics.etp.v12.protocol.data_array.put_data_arrays_response import (
89
+ PutDataArraysResponse,
90
+ )
91
+ from etptypes.energistics.etp.v12.protocol.data_array.put_data_subarrays import (
92
+ PutDataSubarrays,
93
+ )
94
+ from etptypes.energistics.etp.v12.protocol.data_array.put_data_subarrays_response import (
95
+ PutDataSubarraysResponse,
96
+ )
97
+ from etptypes.energistics.etp.v12.protocol.data_array.put_uninitialized_data_arrays import (
98
+ PutUninitializedDataArrays,
99
+ )
100
+ from etptypes.energistics.etp.v12.protocol.data_array.put_uninitialized_data_arrays_response import (
101
+ PutUninitializedDataArraysResponse,
102
+ )
103
+ from etptypes.energistics.etp.v12.protocol.dataspace.delete_dataspaces import (
104
+ DeleteDataspaces,
105
+ )
106
+ from etptypes.energistics.etp.v12.protocol.dataspace.delete_dataspaces_response import (
107
+ DeleteDataspacesResponse,
108
+ )
109
+ from etptypes.energistics.etp.v12.protocol.dataspace.get_dataspaces import GetDataspaces
110
+ from etptypes.energistics.etp.v12.protocol.dataspace.get_dataspaces_response import (
111
+ GetDataspacesResponse,
112
+ )
113
+ from etptypes.energistics.etp.v12.protocol.dataspace.put_dataspaces import PutDataspaces
114
+ from etptypes.energistics.etp.v12.protocol.dataspace.put_dataspaces_response import (
115
+ PutDataspacesResponse,
116
+ )
117
+ from etptypes.energistics.etp.v12.protocol.discovery.get_resources import GetResources
118
+ from etptypes.energistics.etp.v12.protocol.store.delete_data_objects import (
119
+ DeleteDataObjects,
120
+ )
121
+ from etptypes.energistics.etp.v12.protocol.store.delete_data_objects_response import (
122
+ DeleteDataObjectsResponse,
123
+ )
124
+ from etptypes.energistics.etp.v12.protocol.store.get_data_objects import GetDataObjects
125
+ from etptypes.energistics.etp.v12.protocol.store.get_data_objects_response import (
126
+ GetDataObjectsResponse,
127
+ )
128
+ from etptypes.energistics.etp.v12.protocol.store.put_data_objects import PutDataObjects
129
+ from etptypes.energistics.etp.v12.protocol.store.put_data_objects_response import (
130
+ PutDataObjectsResponse,
131
+ )
132
+ from etptypes.energistics.etp.v12.protocol.transaction.commit_transaction import (
133
+ CommitTransaction,
134
+ )
135
+ from etptypes.energistics.etp.v12.protocol.transaction.rollback_transaction import (
136
+ RollbackTransaction,
137
+ )
138
+ from etptypes.energistics.etp.v12.protocol.transaction.start_transaction import (
139
+ StartTransaction,
140
+ )
15
141
  from pydantic import SecretStr
16
- from scipy.interpolate import griddata
17
142
  from xtgeo import RegularSurface
18
143
 
19
-
20
- import pyetp.resqml_objects as ro
21
- #import energyml.resqml.v2_0_1.resqmlv2 as ro
22
- #import energyml.eml.v2_0.commonv2 as roc
144
+ import resqml_objects.v201 as ro
23
145
  from pyetp import utils_arrays, utils_xml
24
146
  from pyetp.config import SETTINGS
25
- from pyetp.types import *
26
147
  from pyetp.uri import DataObjectURI, DataspaceURI
27
- from pyetp.utils import short_id, batched
28
- #from asyncio import timeout
148
+ from resqml_objects import parse_resqml_v201_object, serialize_resqml_v201_object
149
+
150
+ logger = logging.getLogger(__name__)
29
151
 
30
152
  try:
31
153
  # for py >=3.11, we can raise grouped exceptions
32
154
  from builtins import ExceptionGroup # type: ignore
33
155
  except ImportError:
156
+ # Python 3.10
34
157
  def ExceptionGroup(msg, errors):
35
158
  return errors[0]
36
159
 
160
+
37
161
  try:
162
+ # Python >= 3.11
38
163
  from asyncio import timeout
39
164
  except ImportError:
40
- import async_timeout
165
+ # Python 3.10
41
166
  from contextlib import asynccontextmanager
167
+
168
+ import async_timeout
169
+
42
170
  @asynccontextmanager
43
171
  async def timeout(delay: T.Optional[float]) -> T.Any:
44
172
  try:
45
173
  async with async_timeout.timeout(delay):
46
174
  yield None
47
175
  except asyncio.CancelledError as e:
48
- raise asyncio.TimeoutError(f'Timeout ({delay}s)') from e
49
-
50
-
51
- logger = logging.getLogger(__name__)
52
- logger.setLevel(logging.INFO)
176
+ raise asyncio.TimeoutError(f"Timeout ({delay}s)") from e
53
177
 
54
178
 
55
179
  class ETPError(Exception):
@@ -72,7 +196,8 @@ def get_all_etp_protocol_classes():
72
196
  """Update protocol - all exception protocols are now per message"""
73
197
 
74
198
  pddict = ETPConnection.generic_transition_table
75
- pexec = ETPConnection.generic_transition_table["0"]["1000"] # protocol exception
199
+ # protocol exception
200
+ pexec = ETPConnection.generic_transition_table["0"]["1000"]
76
201
 
77
202
  for v in pddict.values():
78
203
  v["1000"] = pexec
@@ -81,20 +206,21 @@ def get_all_etp_protocol_classes():
81
206
 
82
207
 
83
208
  class ETPClient(ETPConnection):
84
-
85
209
  generic_transition_table = get_all_etp_protocol_classes()
86
210
 
87
211
  _recv_events: T.Dict[int, asyncio.Event]
88
212
  _recv_buffer: T.Dict[int, T.List[ETPModel]]
89
213
 
90
- def __init__(self, ws: websockets.WebSocketClientProtocol, timeout=10.):
214
+ def __init__(self, ws: websockets.ClientConnection, timeout=10.0):
91
215
  super().__init__(connection_type=ConnectionType.CLIENT)
92
216
  self._recv_events = {}
93
217
  self._recv_buffer = defaultdict(lambda: list()) # type: ignore
94
218
  self.ws = ws
95
219
 
96
220
  self.timeout = timeout
97
- self.client_info.endpoint_capabilities['MaxWebSocketMessagePayloadSize'] = SETTINGS.MaxWebSocketMessagePayloadSize
221
+ self.client_info.endpoint_capabilities["MaxWebSocketMessagePayloadSize"] = (
222
+ SETTINGS.MaxWebSocketMessagePayloadSize
223
+ )
98
224
  self.__recvtask = asyncio.create_task(self.__recv__())
99
225
 
100
226
  #
@@ -106,11 +232,8 @@ class ETPClient(ETPConnection):
106
232
  return await self._recv(correlation_id)
107
233
 
108
234
  async def _send(self, body: ETPModel):
109
-
110
- msg = Message.get_object_message(
111
- body, message_flags=MessageFlags.FINALPART
112
- )
113
- if msg == None:
235
+ msg = Message.get_object_message(body, message_flags=MessageFlags.FINALPART)
236
+ if msg is None:
114
237
  raise TypeError(f"{type(body)} not valid etp protocol")
115
238
 
116
239
  msg.header.message_id = self.consume_msg_id()
@@ -125,7 +248,9 @@ class ETPClient(ETPConnection):
125
248
  return msg.header.message_id
126
249
 
127
250
  async def _recv(self, correlation_id: int) -> ETPModel:
128
- assert correlation_id in self._recv_events, "trying to recv response on non-existing message"
251
+ assert correlation_id in self._recv_events, (
252
+ "trying to recv response on non-existing message"
253
+ )
129
254
 
130
255
  async with timeout(self.timeout):
131
256
  await self._recv_events[correlation_id].wait()
@@ -139,7 +264,9 @@ class ETPClient(ETPConnection):
139
264
  if len(errors) == 1:
140
265
  raise ETPError.from_proto(errors.pop())
141
266
  elif len(errors) > 1:
142
- raise ExceptionGroup("Server responded with ETPErrors:", ETPError.from_protos(errors))
267
+ raise ExceptionGroup(
268
+ "Server responded with ETPErrors:", ETPError.from_protos(errors)
269
+ )
143
270
 
144
271
  if len(bodies) > 1:
145
272
  logger.warning(f"Recived {len(bodies)} messages, but only expected one")
@@ -147,7 +274,6 @@ class ETPClient(ETPConnection):
147
274
  # ok
148
275
  return bodies[0]
149
276
 
150
-
151
277
  @staticmethod
152
278
  def _parse_error_info(bodies: list[ETPModel]) -> list[ErrorInfo]:
153
279
  # returns all error infos from bodies
@@ -159,22 +285,59 @@ class ETPClient(ETPConnection):
159
285
  errors.extend(body.errors.values())
160
286
  return errors
161
287
 
162
- async def close(self, reason=''):
163
- if self.ws.closed:
164
- self.__recvtask.cancel("stopped")
165
- # fast exit if already closed
166
- return
167
-
288
+ async def close(self, reason=""):
168
289
  try:
169
290
  await self._send(CloseSession(reason=reason))
291
+ except websockets.ConnectionClosed:
292
+ logger.error(
293
+ "Websockets connection is closed, unable to send a CloseSession-message"
294
+ " to the server"
295
+ )
170
296
  finally:
171
- await self.ws.close(reason=reason)
297
+ # Check if the receive task is done, and if not, stop it.
298
+ if not self.__recvtask.done():
299
+ self.__recvtask.cancel("stopped")
300
+
172
301
  self.is_connected = False
173
- self.__recvtask.cancel("stopped")
174
302
 
175
- if len(self._recv_buffer):
176
- logger.error(f"Closed connection - but had stuff left in buffers ({len(self._recv_buffer)})")
177
- # logger.warning(self._recv_buffer) # may contain data so lets not flood logs
303
+ try:
304
+ # Raise any potential exceptions that might have occurred in the
305
+ # receive task
306
+ await self.__recvtask
307
+ except asyncio.CancelledError:
308
+ # No errors except for a cancellation, which is to be expected.
309
+ pass
310
+ except websockets.ConnectionClosed as e:
311
+ # The receive task errored on a closed websockets connection.
312
+ logger.error(
313
+ "The receiver task errored on a closed websockets connection. The "
314
+ f"message was: {e.__class__.__name__}: {e}"
315
+ )
316
+
317
+ if len(self._recv_buffer) > 0:
318
+ logger.error(
319
+ f"Connection is closed, but there are {len(self._recv_buffer)} "
320
+ "messages left in the buffer"
321
+ )
322
+
323
+ # Check if there were any messages left in the websockets connection.
324
+ # Reading them will speed up the closing of the connection.
325
+ counter = 0
326
+ try:
327
+ async for msg in self.ws:
328
+ counter += 1
329
+ except websockets.ConnectionClosed:
330
+ # The websockets connection had already closed. Either successfully
331
+ # or with an error, but we ignore both cases.
332
+ pass
333
+
334
+ if counter > 0:
335
+ logger.error(
336
+ f"There were {counter} unread messages in the websockets connection "
337
+ "after the session was closed"
338
+ )
339
+
340
+ logger.debug("Client closed")
178
341
 
179
342
  #
180
343
  #
@@ -189,13 +352,13 @@ class ETPClient(ETPConnection):
189
352
 
190
353
  # NOTE: should we add a task to auto-clear buffered messages that are never waited on?
191
354
  if msg.is_final_msg():
192
- self._recv_events[msg.header.correlation_id].set() # set response on send event
355
+ # set response on send event
356
+ self._recv_events[msg.header.correlation_id].set()
193
357
 
194
358
  async def __recv__(self):
359
+ logger.debug("starting recv loop")
195
360
 
196
- logger.debug(f"starting recv loop")
197
-
198
- while (True):
361
+ while True:
199
362
  msg_data = await self.ws.recv()
200
363
  msg = Message.decode_binary_message(
201
364
  T.cast(bytes, msg_data), ETPClient.generic_transition_table
@@ -215,21 +378,38 @@ class ETPClient(ETPConnection):
215
378
  async def request_session(self):
216
379
  # Handshake protocol
217
380
  etp_version = Version(major=1, minor=2, revision=0, patch=0)
381
+
382
+ def get_protocol_server_role(protocol: CommunicationProtocol) -> str:
383
+ match protocol:
384
+ case CommunicationProtocol.CORE:
385
+ return "server"
386
+ case CommunicationProtocol.CHANNEL_STREAMING:
387
+ return "producer"
388
+
389
+ return "store"
390
+
218
391
  msg = await self.send(
219
392
  RequestSession(
220
393
  applicationName=SETTINGS.application_name,
221
394
  applicationVersion=SETTINGS.application_version,
222
395
  clientInstanceId=uuid.uuid4(), # type: ignore
223
396
  requestedProtocols=[
224
- SupportedProtocol(protocol=p.value, protocolVersion=etp_version, role='store')
397
+ SupportedProtocol(
398
+ protocol=p.value,
399
+ protocolVersion=etp_version,
400
+ role=get_protocol_server_role(p),
401
+ )
225
402
  for p in CommunicationProtocol
226
403
  ],
227
- supportedDataObjects=[SupportedDataObject(qualifiedType="resqml20.*"), SupportedDataObject(qualifiedType="eml20.*")],
404
+ supportedDataObjects=[
405
+ SupportedDataObject(qualifiedType="resqml20.*"),
406
+ SupportedDataObject(qualifiedType="eml20.*"),
407
+ ],
228
408
  currentDateTime=self.timestamp,
229
409
  earliestRetainedChangeTime=0,
230
410
  endpointCapabilities=dict(
231
411
  MaxWebSocketMessagePayloadSize=DataValue(item=self.max_size)
232
- )
412
+ ),
233
413
  )
234
414
  )
235
415
  assert msg and isinstance(msg, OpenSession)
@@ -237,19 +417,18 @@ class ETPClient(ETPConnection):
237
417
  self.is_connected = True
238
418
 
239
419
  # ignore this endpoint
240
- _ = msg.endpoint_capabilities.pop('MessageQueueDepth', None)
420
+ _ = msg.endpoint_capabilities.pop("MessageQueueDepth", None)
241
421
  self.client_info.negotiate(msg)
242
422
 
243
423
  return self
244
424
 
245
- async def authorize(self, authorization: str, supplemental_authorization: T.Mapping[str, str] = {}):
246
-
247
-
248
-
425
+ async def authorize(
426
+ self, authorization: str, supplemental_authorization: T.Mapping[str, str] = {}
427
+ ):
249
428
  msg = await self.send(
250
429
  Authorize(
251
430
  authorization=authorization,
252
- supplementalAuthorization=supplemental_authorization
431
+ supplementalAuthorization=supplemental_authorization,
253
432
  )
254
433
  )
255
434
  assert msg and isinstance(msg, AuthorizeResponse)
@@ -261,7 +440,7 @@ class ETPClient(ETPConnection):
261
440
  @property
262
441
  def max_size(self):
263
442
  return SETTINGS.MaxWebSocketMessagePayloadSize
264
- #return self.client_info.getCapability("MaxWebSocketMessagePayloadSize")
443
+ # return self.client_info.getCapability("MaxWebSocketMessagePayloadSize")
265
444
 
266
445
  @property
267
446
  def max_array_size(self):
@@ -271,57 +450,103 @@ class ETPClient(ETPConnection):
271
450
  def timestamp(self):
272
451
  return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
273
452
 
274
-
275
453
  def dataspace_uri(self, ds: str) -> DataspaceURI:
276
454
  if ds.count("/") > 1:
277
- raise Exception(f"Max one / in dataspace name")
455
+ raise Exception("Max one / in dataspace name")
278
456
  return DataspaceURI.from_name(ds)
279
-
457
+
280
458
  def list_objects(self, dataspace_uri: DataspaceURI, depth: int = 1) -> list:
281
- return self.send(GetResources(
459
+ return self.send(
460
+ GetResources(
282
461
  scope=ContextScopeKind.TARGETS_OR_SELF,
283
462
  context=ContextInfo(
284
463
  uri=dataspace_uri.raw_uri,
285
464
  depth=depth,
286
465
  dataObjectTypes=[],
287
- navigableEdges=RelationshipKind.PRIMARY,)
288
- )
466
+ navigableEdges=RelationshipKind.PRIMARY,
467
+ ),
289
468
  )
469
+ )
470
+
290
471
  #
291
472
  # dataspace
292
473
  #
293
474
 
294
- async def put_dataspaces(self, *dataspace_uris: DataspaceURI):
475
+ async def get_dataspaces(
476
+ self, store_last_write_filter: int = None
477
+ ) -> GetDataspacesResponse:
478
+ return await self.send(
479
+ GetDataspaces(store_last_write_filter=store_last_write_filter)
480
+ )
481
+
482
+ async def put_dataspaces(
483
+ self,
484
+ legaltags: list[str],
485
+ otherRelevantDataCountries: list[str],
486
+ owners: list[str],
487
+ viewers: list[str],
488
+ *dataspace_uris: DataspaceURI,
489
+ ):
295
490
  _uris = list(map(DataspaceURI.from_any, dataspace_uris))
296
491
  for i in _uris:
297
- if i.raw_uri.count("/") > 4: # includes the 3 eml
298
- raise Exception(f"Max one / in dataspace name")
492
+ if i.raw_uri.count("/") > 4: # includes the 3 eml
493
+ raise Exception("Max one / in dataspace name")
299
494
  time = self.timestamp
300
495
  response = await self.send(
301
- PutDataspaces(dataspaces={
302
- d.raw_uri: Dataspace(uri=d.raw_uri, storeCreated=time, storeLastWrite=time, path=d.dataspace)
303
- for d in _uris
304
- })
496
+ PutDataspaces(
497
+ dataspaces={
498
+ d.raw_uri: Dataspace(
499
+ uri=d.raw_uri,
500
+ storeCreated=time,
501
+ storeLastWrite=time,
502
+ path=d.dataspace,
503
+ custom_data={
504
+ "legaltags": DataValue(
505
+ item=ArrayOfString(values=legaltags)
506
+ ),
507
+ "otherRelevantDataCountries": DataValue(
508
+ item=ArrayOfString(values=otherRelevantDataCountries)
509
+ ),
510
+ "owners": DataValue(item=ArrayOfString(values=owners)),
511
+ "viewers": DataValue(item=ArrayOfString(values=viewers)),
512
+ },
513
+ )
514
+ for d in _uris
515
+ }
516
+ )
517
+ )
518
+ assert isinstance(response, PutDataspacesResponse), (
519
+ "Expected PutDataspacesResponse"
305
520
  )
306
- assert isinstance(response, PutDataspacesResponse), "Expected PutDataspacesResponse"
307
521
 
308
- assert len(response.success) == len(dataspace_uris), f"expected {len(dataspace_uris)} success's"
522
+ assert len(response.success) == len(dataspace_uris), (
523
+ f"expected {len(dataspace_uris)} success's"
524
+ )
309
525
 
310
526
  return response.success
311
527
 
312
- async def put_dataspaces_no_raise(self, *dataspace_uris: DataspaceURI):
528
+ async def put_dataspaces_no_raise(
529
+ self,
530
+ legaltags: list[str],
531
+ otherRelevantDataCountries: list[str],
532
+ owners: list[str],
533
+ viewers: list[str],
534
+ *dataspace_uris: DataspaceURI,
535
+ ):
313
536
  try:
314
- return await self.put_dataspaces(*dataspace_uris)
537
+ return await self.put_dataspaces(
538
+ legaltags, otherRelevantDataCountries, owners, viewers, *dataspace_uris
539
+ )
315
540
  except ETPError:
316
541
  pass
317
542
 
318
543
  async def delete_dataspaces(self, *dataspace_uris: DataspaceURI):
319
-
320
-
321
544
  _uris = list(map(str, dataspace_uris))
322
545
 
323
546
  response = await self.send(DeleteDataspaces(uris=dict(zip(_uris, _uris))))
324
- assert isinstance(response, DeleteDataspacesResponse), "Expected DeleteDataspacesResponse"
547
+ assert isinstance(response, DeleteDataspacesResponse), (
548
+ "Expected DeleteDataspacesResponse"
549
+ )
325
550
  return response.success
326
551
 
327
552
  #
@@ -329,597 +554,276 @@ class ETPClient(ETPConnection):
329
554
  #
330
555
 
331
556
  async def get_data_objects(self, *uris: T.Union[DataObjectURI, str]):
332
-
333
557
  _uris = list(map(str, uris))
334
558
 
335
- msg = await self.send(
336
- GetDataObjects(uris=dict(zip(_uris, _uris)))
337
- )
559
+ msg = await self.send(GetDataObjects(uris=dict(zip(_uris, _uris))))
338
560
  assert isinstance(msg, GetDataObjectsResponse), "Expected dataobjectsresponse"
339
- assert len(msg.data_objects) == len(_uris), "Here we assume that all three objects fit in a single record"
561
+ assert len(msg.data_objects) == len(_uris), (
562
+ "Here we assume that all three objects fit in a single record"
563
+ )
340
564
 
341
565
  return [msg.data_objects[u] for u in _uris]
342
566
 
343
567
  async def put_data_objects(self, *objs: DataObject):
344
-
345
-
346
-
347
568
  response = await self.send(
348
- PutDataObjects(dataObjects={f"{p.resource.name}_{short_id()}": p for p in objs})
569
+ PutDataObjects(
570
+ data_objects={f"{p.resource.name} - {p.resource.uri}": p for p in objs},
571
+ )
572
+ )
573
+
574
+ assert isinstance(response, PutDataObjectsResponse), (
575
+ "Expected PutDataObjectsResponse"
349
576
  )
350
- # logger.info(f"objects {response=:}")
351
- assert isinstance(response, PutDataObjectsResponse), "Expected PutDataObjectsResponse"
352
- # assert len(response.success) == len(objs) # might be 0 if objects exists
353
577
 
354
578
  return response.success
355
579
 
356
- async def get_resqml_objects(self, *uris: T.Union[DataObjectURI, str]) -> T.List[ro.AbstractObject]:
580
+ async def get_resqml_objects(
581
+ self, *uris: T.Union[DataObjectURI, str]
582
+ ) -> T.List[ro.AbstractObject]:
357
583
  data_objects = await self.get_data_objects(*uris)
358
- return utils_xml.parse_resqml_objects(data_objects)
359
-
360
- async def put_resqml_objects(self, *objs: ro.AbstractObject, dataspace_uri: DataspaceURI):
584
+ return [
585
+ parse_resqml_v201_object(data_object.data) for data_object in data_objects
586
+ ]
361
587
 
588
+ async def put_resqml_objects(
589
+ self, *objs: ro.AbstractObject, dataspace_uri: DataspaceURI
590
+ ):
362
591
  time = self.timestamp
363
592
  uris = [DataObjectURI.from_obj(dataspace_uri, obj) for obj in objs]
364
- dobjs = [DataObject(
365
- format="xml",
366
- data=utils_xml.resqml_to_xml(obj),
367
- resource=Resource(
368
- uri=uri.raw_uri,
369
- name=obj.citation.title if obj.citation else obj.__class__.__name__,
370
- lastChanged=time,
371
- storeCreated=time,
372
- storeLastWrite=time,
373
- activeStatus="Inactive", # type: ignore
374
- sourceCount=None,
375
- targetCount=None
593
+ dobjs = [
594
+ DataObject(
595
+ format="xml",
596
+ data=serialize_resqml_v201_object(obj),
597
+ resource=Resource(
598
+ uri=uri.raw_uri,
599
+ name=obj.citation.title if obj.citation else obj.__class__.__name__,
600
+ lastChanged=time,
601
+ storeCreated=time,
602
+ storeLastWrite=time,
603
+ activeStatus="Inactive", # type: ignore
604
+ sourceCount=None,
605
+ targetCount=None,
606
+ ),
376
607
  )
377
- ) for uri, obj in zip(uris, objs)]
608
+ for uri, obj in zip(uris, objs)
609
+ ]
378
610
 
379
- response = await self.put_data_objects(*dobjs)
611
+ _ = await self.put_data_objects(*dobjs)
380
612
  return uris
381
613
 
382
- async def delete_data_objects(self, *uris: T.Union[DataObjectURI, str], pruneContainedObjects=False):
383
-
384
-
614
+ async def delete_data_objects(
615
+ self, *uris: T.Union[DataObjectURI, str], prune_contained_objects=False
616
+ ):
385
617
  _uris = list(map(str, uris))
386
618
 
387
619
  response = await self.send(
388
620
  DeleteDataObjects(
389
621
  uris=dict(zip(_uris, _uris)),
390
- pruneContainedObjects=pruneContainedObjects
622
+ prune_contained_objects=prune_contained_objects,
391
623
  )
392
624
  )
393
- # logger.info(f"delete objects {response=:}")
394
- assert isinstance(response, DeleteDataObjectsResponse), "Expected DeleteDataObjectsResponse"
395
-
396
- return response.deleted_uris
397
- #
398
- # xtgeo
399
- #
400
- @staticmethod
401
- def check_inside(x: float, y: float, patch: ro.Grid2dPatch):
402
- xori = patch.geometry.points.supporting_geometry.origin.coordinate1
403
- yori = patch.geometry.points.supporting_geometry.origin.coordinate2
404
- xmax = xori + (patch.geometry.points.supporting_geometry.offset[0].spacing.value*patch.geometry.points.supporting_geometry.offset[0].spacing.count)
405
- ymax = yori + (patch.geometry.points.supporting_geometry.offset[1].spacing.value*patch.geometry.points.supporting_geometry.offset[1].spacing.count)
406
- if x < xori:
407
- return False
408
- if y < yori:
409
- return False
410
- if x > xmax:
411
- return False
412
- if y > ymax:
413
- return False
414
- return True
415
-
416
- @staticmethod
417
- def find_closest_index(x, y, patch: ro.Grid2dPatch):
418
- x_ind = (x-patch.geometry.points.supporting_geometry.origin.coordinate1)/patch.geometry.points.supporting_geometry.offset[0].spacing.value
419
- y_ind = (y-patch.geometry.points.supporting_geometry.origin.coordinate2)/patch.geometry.points.supporting_geometry.offset[1].spacing.value
420
- return round(x_ind), round(y_ind)
421
-
422
- async def get_surface_value_x_y(self, epc_uri: T.Union[DataObjectURI, str], gri_uri: T.Union[DataObjectURI, str],crs_uri: T.Union[DataObjectURI, str], x: T.Union[int, float], y: T.Union[int, float], method: T.Literal["bilinear", "nearest"]):
423
- gri, = await self.get_resqml_objects(gri_uri) # parallelized using subarray
424
- xori = gri.grid2d_patch.geometry.points.supporting_geometry.origin.coordinate1
425
- yori = gri.grid2d_patch.geometry.points.supporting_geometry.origin.coordinate2
426
- xinc = gri.grid2d_patch.geometry.points.supporting_geometry.offset[0].spacing.value
427
- yinc = gri.grid2d_patch.geometry.points.supporting_geometry.offset[1].spacing.value
428
- max_x_index_in_gri = gri.grid2d_patch.geometry.points.supporting_geometry.offset[0].spacing.count
429
- max_y_index_in_gri = gri.grid2d_patch.geometry.points.supporting_geometry.offset[1].spacing.count
430
- buffer = 4
431
- if not self.check_inside(x, y, gri.grid2d_patch):
432
- logger.info(f"Points not inside {x}:{y} {gri}")
433
- return
434
- uid = DataArrayIdentifier(
435
- uri=str(epc_uri), pathInResource=gri.grid2d_patch.geometry.points.zvalues.values.path_in_hdf_file
436
- )
437
- if max_x_index_in_gri <= 10 or max_y_index_in_gri <= 10:
438
- surf = await self.get_xtgeo_surface(epc_uri, gri_uri, crs_uri)
439
- return surf.get_value_from_xy((x, y), sampling=method)
440
-
441
- x_ind, y_ind = self.find_closest_index(x, y, gri.grid2d_patch)
442
- if method == "nearest":
443
- arr = await self.get_subarray(uid, [x_ind, y_ind], [1, 1])
444
- return arr[0][0]
445
- min_x_ind = max(x_ind-(buffer/2), 0)
446
- min_y_ind = max(y_ind-(buffer/2), 0)
447
- count_x = min(max_x_index_in_gri-min_x_ind, buffer)
448
- count_y = min(max_y_index_in_gri-min_y_ind, buffer)
449
- # shift start index to left if not enough buffer on right
450
- if count_x < buffer:
451
- x_index_to_add = 3 - count_x
452
- min_x_ind_new = max(0, min_x_ind-x_index_to_add)
453
- count_x = count_x + min_x_ind-min_x_ind_new+1
454
- min_x_ind = min_x_ind_new
455
- if count_y < buffer:
456
- y_index_to_add = 3 - count_y
457
- min_y_ind_new = max(0, min_y_ind-y_index_to_add)
458
- count_y = count_y + min_y_ind-min_y_ind_new+1
459
- min_y_ind = min_y_ind_new
460
- arr = await self.get_subarray(uid, [min_x_ind, min_y_ind], [count_x, count_y])
461
- new_x_ori = xori+(min_x_ind*xinc)
462
- new_y_ori = yori+(min_y_ind*yinc)
463
- regridded = RegularSurface(
464
- ncol=arr.shape[0],
465
- nrow=arr.shape[1],
466
- xori=new_x_ori,
467
- yori=new_y_ori,
468
- xinc=xinc,
469
- yinc=yinc,
470
- rotation=0.0,
471
- values=arr.flatten(),
625
+ assert isinstance(response, DeleteDataObjectsResponse), (
626
+ "Expected DeleteDataObjectsResponse"
472
627
  )
473
- return regridded.get_value_from_xy((x, y))
474
-
475
- async def get_xtgeo_surface(self, epc_uri: T.Union[DataObjectURI, str], gri_uri: T.Union[DataObjectURI, str], crs_uri: T.Union[DataObjectURI, str]):
476
- gri, crs, = await self.get_resqml_objects(gri_uri, crs_uri)
477
- rotation = crs.areal_rotation.value
478
- # some checks
479
628
 
480
- assert isinstance(gri, ro.Grid2dRepresentation), "obj must be Grid2DRepresentation"
481
- sgeo = gri.grid2d_patch.geometry.points.supporting_geometry # type: ignore
482
- if sys.version_info[1] != 10:
483
- assert isinstance(gri.grid2d_patch.geometry.points, ro.Point3dZValueArray), "Points must be Point3dZValueArray"
484
- assert isinstance(gri.grid2d_patch.geometry.points.zvalues, ro.DoubleHdf5Array), "Values must be DoubleHdf5Array"
485
- assert isinstance(gri.grid2d_patch.geometry.points.supporting_geometry, ro.Point3dLatticeArray), "Points support_geo must be Point3dLatticeArray"
486
- assert isinstance(sgeo, ro.Point3dLatticeArray), "supporting_geometry must be Point3dLatticeArray"
487
- assert isinstance(gri.grid2d_patch.geometry.points.zvalues.values, ro.Hdf5Dataset), "Values must be Hdf5Dataset"
629
+ return response.deleted_uris
488
630
 
489
- # get array
490
- array = await self.get_array(
491
- DataArrayIdentifier(
492
- uri=str(epc_uri), pathInResource=gri.grid2d_patch.geometry.points.zvalues.values.path_in_hdf_file
631
+ async def start_transaction(
632
+ self, dataspace_uri: DataspaceURI, read_only: bool = True
633
+ ) -> Uuid:
634
+ trans_id = await self.send(
635
+ StartTransaction(
636
+ read_only=read_only, dataspace_uris=[dataspace_uri.raw_uri]
493
637
  )
494
638
  )
495
-
496
- return RegularSurface(
497
- ncol=array.shape[0], nrow=array.shape[1],
498
- xinc=sgeo.offset[0].spacing.value, yinc=sgeo.offset[1].spacing.value, # type: ignore
499
- xori=sgeo.origin.coordinate1, yori=sgeo.origin.coordinate2,
500
- values=array, # type: ignore
501
- rotation=rotation,
502
- masked=True
503
- )
504
- async def start_transaction(self, dataspace_uri: DataspaceURI, readOnly :bool= True) -> Uuid:
505
- trans_id = await self.send(StartTransaction(readOnly=readOnly, dataspaceUris=[dataspace_uri.raw_uri]))
506
639
  if trans_id.successful is False:
507
640
  raise Exception(f"Failed starting transaction {dataspace_uri.raw_uri}")
508
- return Uuid(trans_id.transaction_uuid) #uuid.UUID(bytes=trans_id.transaction_uuid)
509
-
510
- async def commit_transaction(self, transaction_id: Uuid):
511
- r = await self.send(CommitTransaction(transactionUuid=transaction_id))
641
+ # uuid.UUID(bytes=trans_id.transaction_uuid)
642
+ return Uuid(trans_id.transaction_uuid)
643
+
644
async def commit_transaction(self, transaction_uuid: Uuid):
    """Commit the transaction identified by ``transaction_uuid``.

    Sends a ``CommitTransaction`` message and returns the server response.

    Raises:
        Exception: carrying the response's ``failure_reason`` when the
            response reports ``successful`` as ``False``.
    """
    commit_request = CommitTransaction(transaction_uuid=transaction_uuid)
    response = await self.send(commit_request)
    # Only an explicit ``False`` is treated as a failure; any other value of
    # ``successful`` lets the response through unchanged.
    if response.successful is not False:
        return response
    raise Exception(response.failure_reason)
515
-
649
+
516
650
async def rollback_transaction(self, transaction_id: Uuid):
    """Send a ``RollbackTransaction`` for ``transaction_id`` and return the response."""
    rollback_request = RollbackTransaction(transactionUuid=transaction_id)
    return await self.send(rollback_request)
518
-
519
- async def put_xtgeo_surface(self, surface: RegularSurface, epsg_code: int, dataspace_uri: DataspaceURI):
520
- """Returns (epc_uri, crs_uri, gri_uri)"""
521
- assert surface.values is not None, "cannot upload empty surface"
522
-
523
- t_id = await self.start_transaction(dataspace_uri, False)
524
- epc, crs, gri = utils_xml.parse_xtgeo_surface_to_resqml_grid(surface, epsg_code)
525
- epc_uri, crs_uri, gri_uri = await self.put_resqml_objects(epc, crs, gri, dataspace_uri=dataspace_uri)
526
-
527
- await self.put_array(
528
- DataArrayIdentifier(
529
- uri=epc_uri.raw_uri if isinstance(epc_uri, DataObjectURI) else epc_uri,
530
- pathInResource=gri.grid2d_patch.geometry.points.zvalues.values.path_in_hdf_file # type: ignore
531
- ),
532
- surface.values.filled(np.nan).astype(np.float32),
533
- t_id
534
- )
535
-
536
- return epc_uri, gri_uri, crs_uri
537
652
 
538
653
  #
539
- # resqpy meshes
654
+ # xtgeo
540
655
  #
541
-
542
- async def get_epc_mesh(self, epc_uri: T.Union[DataObjectURI, str], uns_uri: T.Union[DataObjectURI, str]):
543
- uns, = await self.get_resqml_objects(uns_uri)
544
-
656
+ async def get_xtgeo_surface(
657
+ self,
658
+ epc_uri: T.Union[DataObjectURI, str],
659
+ gri_uri: T.Union[DataObjectURI, str],
660
+ crs_uri: T.Union[DataObjectURI, str],
661
+ ):
662
+ (
663
+ gri,
664
+ crs,
665
+ ) = await self.get_resqml_objects(gri_uri, crs_uri)
666
+ rotation = crs.areal_rotation.value
545
667
  # some checks
546
- assert isinstance(uns, ro.UnstructuredGridRepresentation), "obj must be UnstructuredGridRepresentation"
547
- assert isinstance(uns.geometry, ro.UnstructuredGridGeometry), "geometry must be UnstructuredGridGeometry"
548
- if sys.version_info[1] != 10:
549
- assert isinstance(uns.geometry.points, ro.Point3dHdf5Array), "points must be Point3dHdf5Array"
550
- assert isinstance(uns.geometry.faces_per_cell.elements, ro.IntegerHdf5Array), "faces_per_cell must be IntegerHdf5Array"
551
- assert isinstance(uns.geometry.faces_per_cell.cumulative_length, ro.IntegerHdf5Array), "faces_per_cell cl must be IntegerHdf5Array"
552
- assert isinstance(uns.geometry.points.coordinates, ro.Hdf5Dataset), "coordinates must be Hdf5Dataset"
553
668
 
554
- # # get array
555
- points = await self.get_array(
556
- DataArrayIdentifier(
557
- uri=str(epc_uri), pathInResource=uns.geometry.points.coordinates.path_in_hdf_file
558
- )
559
- )
560
- nodes_per_face = await self.get_array(
561
- DataArrayIdentifier(
562
- uri=str(epc_uri), pathInResource=uns.geometry.nodes_per_face.elements.values.path_in_hdf_file
563
- )
564
- )
565
- nodes_per_face_cl = await self.get_array(
566
- DataArrayIdentifier(
567
- uri=str(epc_uri), pathInResource=uns.geometry.nodes_per_face.cumulative_length.values.path_in_hdf_file
568
- )
569
- )
570
- faces_per_cell = await self.get_array(
571
- DataArrayIdentifier(
572
- uri=str(epc_uri), pathInResource=uns.geometry.faces_per_cell.elements.values.path_in_hdf_file
573
- )
574
- )
575
- faces_per_cell_cl = await self.get_array(
576
- DataArrayIdentifier(
577
- uri=str(epc_uri), pathInResource=uns.geometry.faces_per_cell.cumulative_length.values.path_in_hdf_file
578
- )
669
+ assert isinstance(gri, ro.Grid2dRepresentation), (
670
+ "obj must be Grid2DRepresentation"
579
671
  )
580
- cell_face_is_right_handed = await self.get_array(
581
- DataArrayIdentifier(
582
- uri=str(epc_uri), pathInResource=uns.geometry.cell_face_is_right_handed.values.path_in_hdf_file
672
+ sgeo = gri.grid2d_patch.geometry.points.supporting_geometry # type: ignore
673
+ if sys.version_info[1] != 10:
674
+ assert isinstance(
675
+ gri.grid2d_patch.geometry.points, ro.Point3dZValueArray
676
+ ), "Points must be Point3dZValueArray"
677
+ assert isinstance(
678
+ gri.grid2d_patch.geometry.points.zvalues, ro.DoubleHdf5Array
679
+ ), "Values must be DoubleHdf5Array"
680
+ assert isinstance(
681
+ gri.grid2d_patch.geometry.points.supporting_geometry,
682
+ ro.Point3dLatticeArray,
683
+ ), "Points support_geo must be Point3dLatticeArray"
684
+ assert isinstance(sgeo, ro.Point3dLatticeArray), (
685
+ "supporting_geometry must be Point3dLatticeArray"
583
686
  )
584
- )
585
-
586
- return uns, points, nodes_per_face, nodes_per_face_cl, faces_per_cell, faces_per_cell_cl, cell_face_is_right_handed
587
-
588
- async def get_epc_mesh_property(self, epc_uri: T.Union[DataObjectURI, str], prop_uri: T.Union[DataObjectURI, str]):
589
- cprop0, = await self.get_resqml_objects(prop_uri)
590
-
591
- # some checks
592
- assert isinstance(cprop0, ro.ContinuousProperty) or isinstance(cprop0, ro.DiscreteProperty), "prop must be a Property"
593
- assert len(cprop0.patch_of_values) == 1, "property obj must have exactly one patch of values"
687
+ assert isinstance(
688
+ gri.grid2d_patch.geometry.points.zvalues.values, ro.Hdf5Dataset
689
+ ), "Values must be Hdf5Dataset"
594
690
 
595
- # # get array
596
- values = await self.get_array(
691
+ # get array
692
+ array = await self.get_array(
597
693
  DataArrayIdentifier(
598
- uri=str(epc_uri), pathInResource=cprop0.patch_of_values[0].values.values.path_in_hdf_file,
694
+ uri=str(epc_uri),
695
+ pathInResource=gri.grid2d_patch.geometry.points.zvalues.values.path_in_hdf_file,
599
696
  )
600
697
  )
601
698
 
602
- return cprop0, values
603
-
604
- @staticmethod
605
- def check_bound(points, x: float, y: float):
606
- if x > points[:, 0].max() or x < points[:, 0].min():
607
- return False
608
- if y > points[:, 1].max() or y < points[:, 1].min():
609
- return False
610
- return True
611
-
612
- async def get_epc_mesh_property_x_y(self, epc_uri: T.Union[DataObjectURI, str], uns_uri: T.Union[DataObjectURI, str], prop_uri: T.Union[DataObjectURI, str], x: float, y: float):
613
- uns, = await self.get_resqml_objects(uns_uri)
614
- points = await self.get_array(
615
- DataArrayIdentifier(uri=str(epc_uri), pathInResource=uns.geometry.points.coordinates.path_in_hdf_file))
616
- chk = self.check_bound(points, x, y)
617
- if chk == False:
618
- return None
619
- unique_y = np.unique(points[:, 1])
620
- y_smaller_sorted = np.sort(unique_y[np.argwhere(unique_y < y).flatten()])
621
- if y_smaller_sorted.size > 1:
622
- y_floor = y_smaller_sorted[-2]
623
- elif y_smaller_sorted.size == 1:
624
- y_floor = y_smaller_sorted[-1]
625
- else:
626
- pass
627
- y_larger_sorted = np.sort(unique_y[np.argwhere(unique_y > y).flatten()])
628
- if y_larger_sorted.size > 1:
629
- y_ceil = y_larger_sorted[1]
630
- elif y_larger_sorted.size == 1:
631
- y_ceil = y_larger_sorted[0]
632
- else:
633
- pass
634
- start_new_row_idx = np.argwhere(np.diff(points[:, 1]) != 0).flatten() + 1
635
-
636
- to_fetch = []
637
- initial_result_arr_idx = 0
638
- for i in range(start_new_row_idx.size-1):
639
- sliced = points[start_new_row_idx[i]:start_new_row_idx[i+1], :]
640
- if sliced[0, 1] <= y_ceil and sliced[0, 1] >= y_floor:
641
- # Found slice that has same y
642
- x_diff = sliced[:, 0]-x
643
- if all([np.any((x_diff >= 0)), np.any((x_diff <= 0))]): # y within this slice
644
- first_idx = start_new_row_idx[i]
645
- count = start_new_row_idx[i+1]-first_idx
646
- to_fetch.append([start_new_row_idx[i], start_new_row_idx[i+1], count, initial_result_arr_idx])
647
- initial_result_arr_idx += count
648
-
649
- total_points_filtered = sum([i[2] for i in to_fetch])
650
-
651
- cprop, = await self.get_resqml_objects(prop_uri)
652
- assert str(cprop.indexable_element) == 'IndexableElements.NODES'
653
- props_uid = DataArrayIdentifier(
654
- uri=str(epc_uri), pathInResource=cprop.patch_of_values[0].values.values.path_in_hdf_file)
655
- meta, = await self.get_array_metadata(props_uid)
656
- filtered_points = np.zeros((total_points_filtered, 3), dtype=np.float64)
657
- all_values = np.empty(total_points_filtered, dtype=np.float64)
658
-
659
- async def populate(i):
660
- end_indx = i[2]+i[3]
661
- filtered_points[i[3]:end_indx] = points[i[0]:i[1]]
662
- if utils_arrays.get_nbytes(meta) * i[2]/points.shape[0] > self.max_array_size:
663
- all_values[i[3]:end_indx] = await self._get_array_chuncked(props_uid, i[0], i[2])
664
- else:
665
- all_values[i[3]:end_indx] = await self.get_subarray(props_uid, [i[0]], [i[2]])
666
- return
667
-
668
- r = await asyncio.gather(*[populate(i) for i in to_fetch])
669
-
670
- if isinstance(cprop, ro.DiscreteProperty):
671
- method = "nearest"
672
- else:
673
- method = "linear"
674
-
675
- # resolution= np.mean(np.diff(filtered[:,-1]))
676
- top = round(np.min(filtered_points[:, -1]), 1)
677
- base = round(np.max(filtered_points[:, -1]), 1)
678
- requested_depth = np.arange(top, base+1, 100)
679
- requested_depth = requested_depth[requested_depth > 0]
680
- request = np.tile([x, y, 0], (requested_depth.size, 1))
681
- request[:, 2] = requested_depth
682
- interpolated = griddata(filtered_points, all_values, request, method=method)
683
- response = np.vstack((requested_depth, interpolated))
684
- response_filtered = response[:, ~np.isnan(response[1])]
685
- return {"depth": response_filtered[0], "values": response_filtered[1]}
686
-
687
- async def put_rddms_property(self, epc_uri: T.Union[DataObjectURI , str],
688
- cprop0: T.Union[ro.ContinuousProperty, ro.DiscreteProperty],
689
- propertykind0: ro.PropertyKind,
690
- array_ref: np.ndarray,
691
- dataspace_uri: DataspaceURI ):
692
-
693
- assert isinstance(cprop0, ro.ContinuousProperty) or isinstance(cprop0, ro.DiscreteProperty), "prop must be a Property"
694
- assert len(cprop0.patch_of_values) == 1, "property obj must have exactly one patch of values"
695
-
696
- st = time.time()
697
- t_id = await self.start_transaction(dataspace_uri, False)
698
- propkind_uri = [""] if (propertykind0 is None) else (await self.put_resqml_objects(propertykind0, dataspace_uri=dataspace_uri))
699
- cprop_uri = await self.put_resqml_objects(cprop0, dataspace_uri=dataspace_uri)
700
- delay = time.time() - st
701
- logger.debug(f"pyetp: put_rddms_property: put objects took {delay} s")
702
-
703
- st = time.time()
704
- response = await self.put_array(
705
- DataArrayIdentifier(
706
- uri=epc_uri.raw_uri if isinstance(epc_uri, DataObjectURI) else epc_uri,
707
- pathInResource=cprop0.patch_of_values[0].values.values.path_in_hdf_file,
708
- ),
709
- array_ref, # type: ignore
710
- t_id
699
+ return RegularSurface(
700
+ ncol=array.shape[0],
701
+ nrow=array.shape[1],
702
+ xinc=sgeo.offset[0].spacing.value,
703
+ yinc=sgeo.offset[1].spacing.value,
704
+ xori=sgeo.origin.coordinate1,
705
+ yori=sgeo.origin.coordinate2,
706
+ values=array,
707
+ rotation=rotation,
708
+ masked=True,
711
709
  )
712
- delay = time.time() - st
713
- logger.debug(f"pyetp: put_rddms_property: put array ({array_ref.shape}) took {delay} s")
714
- return cprop_uri, propkind_uri
715
710
 
716
- async def put_epc_mesh(
717
- self, epc_filename: str, title_in: str, property_titles: T.List[str], projected_epsg: int,
718
- dataspace_uri: DataspaceURI
711
+ async def put_xtgeo_surface(
712
+ self,
713
+ surface: RegularSurface,
714
+ epsg_code: int,
715
+ dataspace_uri: DataspaceURI,
716
+ handle_transaction: bool = True,
719
717
  ):
720
- uns, crs, epc, timeseries, hexa = utils_xml.convert_epc_mesh_to_resqml_mesh(epc_filename, title_in, projected_epsg)
721
- t_id = await self.start_transaction(dataspace_uri, False)
722
- epc_uri, crs_uri, uns_uri = await self.put_resqml_objects(epc, crs, uns, dataspace_uri=dataspace_uri)
723
- timeseries_uri = ""
724
- if timeseries is not None:
725
- timeseries_uris = await self.put_resqml_objects(timeseries, dataspace_uri=dataspace_uri)
726
- timeseries_uri = list(timeseries_uris)[0] if (len(list(timeseries_uris)) > 0) else ""
727
-
728
- #
729
- # mesh geometry (six arrays)
730
- #
731
- response = await self.put_array(
732
- DataArrayIdentifier(
733
- uri=epc_uri.raw_uri if isinstance(epc_uri, DataObjectURI) else epc_uri,
734
- pathInResource=uns.geometry.points.coordinates.path_in_hdf_file
735
- ),
736
- hexa.points_cached # type: ignore
737
- )
718
+ """Returns (epc_uri, gri_uri, crs_uri).
738
719
 
739
- response = await self.put_array(
740
- DataArrayIdentifier(
741
- uri=epc_uri.raw_uri if isinstance(epc_uri, DataObjectURI) else epc_uri,
742
- pathInResource=uns.geometry.nodes_per_face.elements.values.path_in_hdf_file
743
- ),
744
- hexa.nodes_per_face.astype(np.int32) # type: ignore
745
- )
746
-
747
- response = await self.put_array(
748
- DataArrayIdentifier(
749
- uri=epc_uri.raw_uri if isinstance(epc_uri, DataObjectURI) else epc_uri,
750
- pathInResource=uns.geometry.nodes_per_face.cumulative_length.values.path_in_hdf_file
751
- ),
752
- hexa.nodes_per_face_cl # type: ignore
753
- )
720
+ If `handle_transaction == True` we start a transaction, and commit it
721
+ after the data has been uploaded. Otherwise, we do not handle the
722
+ transactions at all and assume that the user will start and commit the
723
+ transaction themselves.
724
+ """
725
+ assert surface.values is not None, "cannot upload empty surface"
754
726
 
755
- response = await self.put_array(
756
- DataArrayIdentifier(
757
- uri=epc_uri.raw_uri if isinstance(epc_uri, DataObjectURI) else epc_uri,
758
- pathInResource=uns.geometry.faces_per_cell.elements.values.path_in_hdf_file
759
- ),
760
- hexa.faces_per_cell # type: ignore
761
- )
727
+ if handle_transaction:
728
+ transaction_uuid = await self.start_transaction(
729
+ dataspace_uri, read_only=False
730
+ )
762
731
 
763
- response = await self.put_array(
764
- DataArrayIdentifier(
765
- uri=epc_uri.raw_uri if isinstance(epc_uri, DataObjectURI) else epc_uri,
766
- pathInResource=uns.geometry.faces_per_cell.cumulative_length.values.path_in_hdf_file
767
- ),
768
- hexa.faces_per_cell_cl # type: ignore
732
+ epc, crs, gri = utils_xml.parse_xtgeo_surface_to_resqml_grid(surface, epsg_code)
733
+ epc_uri, crs_uri, gri_uri = await self.put_resqml_objects(
734
+ epc, crs, gri, dataspace_uri=dataspace_uri
769
735
  )
770
736
 
771
- response = await self.put_array(
737
+ await self.put_array(
772
738
  DataArrayIdentifier(
773
739
  uri=epc_uri.raw_uri if isinstance(epc_uri, DataObjectURI) else epc_uri,
774
- pathInResource=uns.geometry.cell_face_is_right_handed.values.path_in_hdf_file
740
+ pathInResource=gri.grid2d_patch.geometry.points.zvalues.values.path_in_hdf_file, # type: ignore
775
741
  ),
776
- hexa.cell_face_is_right_handed # type: ignore
742
+ surface.values.filled(np.nan).astype(np.float32),
777
743
  )
778
- await self.commit_transaction(t_id)
779
- #
780
- # mesh properties: one Property, one array of values, and an optional PropertyKind per property
781
- #
782
- prop_rddms_uris = {}
783
- for propname in property_titles:
784
- if timeseries is not None:
785
- time_indices = list(range(len(timeseries.time)))
786
- cprop0s, props, propertykind0 = utils_xml.convert_epc_mesh_property_to_resqml_mesh(epc_filename, hexa, propname, uns, epc, timeseries=timeseries, time_indices=time_indices)
787
- else:
788
- time_indices = [-1]
789
- cprop0s, props, propertykind0 = utils_xml.convert_epc_mesh_property_to_resqml_mesh(epc_filename, hexa, propname, uns, epc)
790
-
791
- if cprop0s is None:
792
- continue
793
744
 
794
- cprop_uris = []
795
- for cprop0, prop, time_index in zip(cprop0s, props, time_indices):
796
- cprop_uri, propkind_uri = await self.put_rddms_property(epc_uri, cprop0, propertykind0, prop.array_ref(), dataspace_uri)
797
- cprop_uris.extend(cprop_uri)
798
- prop_rddms_uris[propname] = [propkind_uri, cprop_uris]
745
+ if handle_transaction:
746
+ await self.commit_transaction(transaction_uuid=transaction_uuid)
799
747
 
800
- return [epc_uri, crs_uri, uns_uri, timeseries_uri], prop_rddms_uris
801
-
802
- async def get_mesh_points(self, epc_uri: T.Union[DataObjectURI, str], uns_uri: T.Union[DataObjectURI, str]):
803
- uns, = await self.get_resqml_objects(uns_uri)
804
- points = await self.get_array(
805
- DataArrayIdentifier(
806
- uri=str(epc_uri), pathInResource=uns.geometry.points.coordinates.path_in_hdf_file
807
- )
808
- )
809
- return points
810
-
811
- async def get_epc_property_surface_slice_node(self, epc_uri: T.Union[DataObjectURI, str], cprop0: ro.AbstractObject, points: np.ndarray, node_index: int, n_node_per_pos: int):
812
- # indexing_array = np.arange(0, points.shape[0], 1, dtype=np.int32)[node_index::n_node_per_pos]
813
- indexing_array = np.arange(node_index, points.shape[0], n_node_per_pos, dtype=np.int32)
814
- results = points[indexing_array, :]
815
- arr = await asyncio.gather(*[self.get_subarray(DataArrayIdentifier(
816
- uri=str(epc_uri), pathInResource=cprop0.patch_of_values[0].values.values.path_in_hdf_file,),
817
- [i], [1]) for i in indexing_array])
818
- arr = np.array(arr).flatten()
819
- assert results.shape[0] == arr.size
820
- results[:, 2] = arr
821
- return results
822
-
823
- async def get_epc_property_surface_slice_cell(self, epc_uri: T.Union[DataObjectURI, str], cprop0: ro.AbstractObject, points: np.ndarray, node_index: int, n_node_per_pos: int, get_cell_pos=True):
824
- m, = await self.get_array_metadata(DataArrayIdentifier(
825
- uri=str(epc_uri), pathInResource=cprop0.patch_of_values[0].values.values.path_in_hdf_file,))
826
- n_cells = m.dimensions[0]
827
- layers_per_sediment_unit = 2
828
- n_cell_per_pos = n_node_per_pos - 1
829
- indexing_array = np.arange(node_index, n_cells, n_cell_per_pos, dtype=np.int32)
830
- if get_cell_pos:
831
- results = utils_arrays.get_cells_positions(points, n_cells, n_cell_per_pos, layers_per_sediment_unit, n_node_per_pos, node_index)
832
- else:
833
- results = np.zeros((int(n_cells/n_cell_per_pos), 3), dtype=np.float64)
834
- arr = await asyncio.gather(*[self.get_subarray(DataArrayIdentifier(
835
- uri=str(epc_uri), pathInResource=cprop0.patch_of_values[0].values.values.path_in_hdf_file,),
836
- [i], [1]) for i in indexing_array])
837
- arr = np.array(arr).flatten()
838
- assert results.shape[0] == arr.size
839
- results[:, 2] = arr
840
- return results
841
-
842
- async def get_epc_property_surface_slice(self, epc_uri: T.Union[DataObjectURI, str], uns_uri: T.Union[DataObjectURI, str], prop_uri: T.Union[DataObjectURI, str], node_index: int, n_node_per_pos: int):
843
- # n_node_per_pos number of nodes in a 1D location
844
- # node_index index of slice from top. Warmth has 2 nodes per sediment layer. E.g. top of second layer will have index 2
845
- points = await self.get_mesh_points(epc_uri, uns_uri)
846
- cprop0, = await self.get_resqml_objects(prop_uri)
847
- prop_at_node = False
848
- if str(cprop0.indexable_element) == 'IndexableElements.NODES':
849
- prop_at_node = True
850
- # node_per_sed = 2
851
- # n_sed_node = n_sed *node_per_sed
852
- # n_crust_node = 4
853
- # n_node_per_pos = n_sed_node + n_crust_node
854
- # start_idx_pos = sediment_id *node_per_sed
855
- if prop_at_node:
856
- return await self.get_epc_property_surface_slice_node(epc_uri, cprop0, points, node_index, n_node_per_pos)
857
- else:
858
- return await self.get_epc_property_surface_slice_cell(epc_uri, cprop0, points, node_index, n_node_per_pos)
859
-
860
- async def get_epc_property_surface_slice_xtgeo(self, epc_uri: T.Union[DataObjectURI, str], uns_uri: T.Union[DataObjectURI, str], prop_uri: T.Union[DataObjectURI, str], node_index: int, n_node_per_pos: int):
861
- data = await self.get_epc_property_surface_slice(epc_uri, uns_uri, prop_uri, node_index, n_node_per_pos)
862
- return utils_arrays.grid_xtgeo(data)
748
+ return epc_uri, gri_uri, crs_uri
863
749
 
864
750
  #
865
751
  # array
866
752
  #
867
753
 
868
754
async def get_array_metadata(self, *uids: DataArrayIdentifier):
    """Fetch the metadata for each of the given data array identifiers.

    The request and the response are both keyed by each identifier's
    ``path_in_resource``.

    Returns:
        A list of metadata entries in the same order as ``uids``.

    Raises:
        ETPError: (code 11) when the number of metadata entries in the
            response differs from the number of identifiers passed in.
    """
    requested = {}
    for uid in uids:
        requested[uid.path_in_resource] = uid

    response = await self.send(GetDataArrayMetadata(dataArrays=requested))
    assert isinstance(response, GetDataArrayMetadataResponse)

    found = response.array_metadata
    if len(found) != len(uids):
        raise ETPError(f"Not all uids found ({uids})", 11)

    # Re-index by the caller's argument order rather than the response order.
    return [found[uid.path_in_resource] for uid in uids]
881
765
 
882
766
  async def get_array(self, uid: DataArrayIdentifier):
883
-
884
-
885
- # Check if we can upload the full array in one go.
886
- meta, = await self.get_array_metadata(uid)
887
- if utils_arrays.get_nbytes(meta) > self.max_array_size:
888
- return await self._get_array_chuncked(uid)
767
+ # Check if we can download the full array in one go.
768
+ (meta,) = await self.get_array_metadata(uid)
769
+ if utils_arrays.get_transport_array_size(meta) > self.max_array_size:
770
+ return await self._get_array_chunked(uid)
889
771
 
890
772
  response = await self.send(
891
773
  GetDataArrays(dataArrays={uid.path_in_resource: uid})
892
774
  )
893
- assert isinstance(response, GetDataArraysResponse), "Expected GetDataArraysResponse"
775
+ assert isinstance(response, GetDataArraysResponse), (
776
+ "Expected GetDataArraysResponse"
777
+ )
894
778
 
895
779
  arrays = list(response.data_arrays.values())
896
- return utils_arrays.to_numpy(arrays[0])
897
-
898
- async def put_array(self, uid: DataArrayIdentifier, data: np.ndarray, transaction_id: Uuid | None = None):
899
-
780
+ return utils_arrays.get_numpy_array_from_etp_data_array(arrays[0])
900
781
 
782
async def put_array(
    self,
    uid: DataArrayIdentifier,
    data: np.ndarray,
):
    """Upload ``data`` as the data array identified by ``uid``.

    The array is first registered through ``_put_uninitialized_data_array``
    using the shape of ``data`` and the logical/transport array types derived
    from its dtype. If ``data.nbytes`` exceeds ``self.max_array_size`` the
    upload is delegated to ``_put_array_chunked``; otherwise the whole array
    is sent in a single ``PutDataArrays`` message.

    Returns:
        The result of ``_put_array_chunked`` for large arrays, otherwise the
        ``success`` member of the ``PutDataArraysResponse``.
    """
    array_types = utils_arrays.get_logical_and_transport_array_types(data.dtype)
    logical_type, transport_type = array_types

    await self._put_uninitialized_data_array(
        uid,
        data.shape,
        transport_array_type=transport_type,
        logical_array_type=logical_type,
    )

    # Arrays too large for a single message are uploaded piecewise.
    if data.nbytes > self.max_array_size:
        return await self._put_array_chunked(uid, data)

    etp_array = utils_arrays.get_etp_data_array_from_numpy(data)
    entry = PutDataArraysType(uid=uid, array=etp_array)
    request = PutDataArrays(data_arrays={uid.path_in_resource: entry})
    response = await self.send(request)

    assert isinstance(response, PutDataArraysResponse), (
        "Expected PutDataArraysResponse"
    )
    assert len(response.success) == 1, "expected one success from put_array"

    return response.success
916
817
 
917
- async def get_subarray(self, uid: DataArrayIdentifier, starts: T.Union[np.ndarray, T.List[int]], counts: T.Union[np.ndarray, T.List[int]]):
818
+ async def get_subarray(
819
+ self,
820
+ uid: DataArrayIdentifier,
821
+ starts: T.Union[np.ndarray, T.List[int]],
822
+ counts: T.Union[np.ndarray, T.List[int]],
823
+ ):
918
824
  starts = np.array(starts).astype(np.int64)
919
825
  counts = np.array(counts).astype(np.int64)
920
826
 
921
-
922
-
923
827
  logger.debug(f"get_subarray {starts=:} {counts=:}")
924
828
 
925
829
  payload = GetDataSubarraysType(
@@ -930,23 +834,32 @@ class ETPClient(ETPConnection):
930
834
  response = await self.send(
931
835
  GetDataSubarrays(dataSubarrays={uid.path_in_resource: payload})
932
836
  )
933
- assert isinstance(response, GetDataSubarraysResponse), "Expected GetDataSubarraysResponse"
837
+ assert isinstance(response, GetDataSubarraysResponse), (
838
+ "Expected GetDataSubarraysResponse"
839
+ )
934
840
 
935
841
  arrays = list(response.data_subarrays.values())
936
- return utils_arrays.to_numpy(arrays[0])
937
-
938
- async def put_subarray(self, uid: DataArrayIdentifier, data: np.ndarray, starts: T.Union[np.ndarray, T.List[int]], counts: T.Union[np.ndarray, T.List[int]]):
939
-
842
+ return utils_arrays.get_numpy_array_from_etp_data_array(arrays[0])
843
+
844
+ async def put_subarray(
845
+ self,
846
+ uid: DataArrayIdentifier,
847
+ data: np.ndarray,
848
+ starts: T.Union[np.ndarray, T.List[int]],
849
+ counts: T.Union[np.ndarray, T.List[int]],
850
+ ):
851
+ # NOTE: This function assumes that the user (or previous methods) have
852
+ # called _put_uninitialized_data_array.
940
853
 
941
854
  # starts [start_X, starts_Y]
942
855
  # counts [count_X, count_Y]
943
- starts = np.array(starts).astype(np.int64) # len = 2 [x_start_index, y_start_index]
944
- counts = np.array(counts).astype(np.int64) # len = 2
945
- ends = starts + counts # len = 2
946
-
856
+ # len = 2 [x_start_index, y_start_index]
857
+ starts = np.array(starts).astype(np.int64)
858
+ counts = np.array(counts).astype(np.int64) # len = 2
859
+ ends = starts + counts # len = 2
947
860
 
948
861
  slices = tuple(map(lambda se: slice(se[0], se[1]), zip(starts, ends)))
949
- dataarray = utils_arrays.to_data_array(data[slices])
862
+ dataarray = utils_arrays.get_etp_data_array_from_numpy(data[slices])
950
863
  payload = PutDataSubarraysType(
951
864
  uid=uid,
952
865
  data=dataarray.data,
@@ -954,25 +867,33 @@ class ETPClient(ETPConnection):
954
867
  counts=counts.tolist(),
955
868
  )
956
869
 
957
- logger.debug(f"put_subarray {data.shape=:} {starts=:} {counts=:} {dataarray.data.item.__class__.__name__}")
870
+ logger.debug(
871
+ f"put_subarray {data.shape=:} {starts=:} {counts=:} "
872
+ f"{dataarray.data.item.__class__.__name__}"
873
+ )
958
874
 
959
875
  response = await self.send(
960
876
  PutDataSubarrays(dataSubarrays={uid.path_in_resource: payload})
961
877
  )
962
- assert isinstance(response, PutDataSubarraysResponse), "Expected PutDataSubarraysResponse"
878
+ assert isinstance(response, PutDataSubarraysResponse), (
879
+ "Expected PutDataSubarraysResponse"
880
+ )
963
881
  assert len(response.success) == 1, "expected one success"
964
882
  return response.success
965
883
 
966
884
  #
967
- # chuncked get array - ETP will not chunck response - so we need to do it manually
885
+ # chunked get array - ETP will not chunk response - so we need to do it manually
968
886
  #
969
887
 
970
- def _get_chunk_sizes(self, shape, dtype: np.dtype[T.Any] = np.dtype(np.float32), offset=0):
888
+ def _get_chunk_sizes(
889
+ self, shape, dtype: np.dtype[T.Any] = np.dtype(np.float32), offset=0
890
+ ):
971
891
  shape = np.array(shape)
972
892
 
973
893
  # capsize blocksize
974
- max_items = self.max_array_size / dtype.itemsize # remove 512 bytes for headers and body
975
- block_size = np.power(max_items, 1. / len(shape))
894
+ # remove 512 bytes for headers and body
895
+ max_items = self.max_array_size / dtype.itemsize
896
+ block_size = np.power(max_items, 1.0 / len(shape))
976
897
  block_size = min(2048, int(block_size // 2) * 2)
977
898
 
978
899
  assert block_size > 8, "computed blocksize unreasonable small"
@@ -991,8 +912,21 @@ class ETPClient(ETPConnection):
991
912
  continue
992
913
  yield starts, counts
993
914
 
994
- async def _get_array_chuncked(self, uid: DataArrayIdentifier, offset: int = 0, total_count: T.Union[int, None] = None):
915
async def _get_array_chuncked(self, *args, **kwargs):
    """Deprecated, misspelled alias of ``_get_array_chunked``.

    Emits a ``DeprecationWarning`` and forwards all positional and keyword
    arguments unchanged.

    Returns:
        Whatever ``_get_array_chunked`` returns for the same arguments.
    """
    warnings.warn(
        "This function is deprecated and will be removed in a later version of "
        "pyetp. Please use the updated function 'pyetp._get_array_chunked'.",
        DeprecationWarning,
        stacklevel=2,
    )
    # The delegate is a coroutine function, so it must be awaited here.
    # Returning the bare call would make ``await self._get_array_chuncked(...)``
    # evaluate to an un-awaited coroutine object instead of the result.
    return await self._get_array_chunked(*args, **kwargs)
995
923
 
924
+ async def _get_array_chunked(
925
+ self,
926
+ uid: DataArrayIdentifier,
927
+ offset: int = 0,
928
+ total_count: T.Union[int, None] = None,
929
+ ):
996
930
  metadata = (await self.get_array_metadata(uid))[0]
997
931
  if len(metadata.dimensions) != 1 and offset != 0:
998
932
  raise Exception("Offset is only implemented for 1D array")
@@ -1001,7 +935,9 @@ class ETPClient(ETPConnection):
1001
935
  buffer_shape = np.array([total_count], dtype=np.int64)
1002
936
  else:
1003
937
  buffer_shape = np.array(metadata.dimensions, dtype=np.int64)
1004
- dtype = utils_arrays.get_dtype(metadata.transport_array_type)
938
+ dtype = utils_arrays.get_dtype_from_any_array_type(
939
+ metadata.transport_array_type
940
+ )
1005
941
  buffer = np.zeros(buffer_shape, dtype=dtype)
1006
942
  params = []
1007
943
 
@@ -1009,62 +945,67 @@ class ETPClient(ETPConnection):
1009
945
  params.append([starts, counts])
1010
946
  array = await self.get_subarray(uid, starts, counts)
1011
947
  ends = starts + counts
1012
- slices = tuple(map(lambda se: slice(se[0], se[1]), zip(starts-offset, ends-offset)))
948
+ slices = tuple(
949
+ map(lambda se: slice(se[0], se[1]), zip(starts - offset, ends - offset))
950
+ )
1013
951
  buffer[slices] = array
1014
952
  return
1015
- # coro = [populate(starts, counts) for starts, counts in self._get_chunk_sizes(buffer_shape, dtype, offset)]
1016
- # logger.debug(f"Concurrent request: {self.max_concurrent_requests}")
1017
- # for i in batched(coro, self.max_concurrent_requests):
1018
- # await asyncio.gather(*i)
1019
- r = await asyncio.gather(*[
1020
- populate(starts, counts)
1021
- for starts, counts in self._get_chunk_sizes(buffer_shape, dtype, offset)
1022
- ])
1023
953
 
1024
- return buffer
954
+ _ = await asyncio.gather(
955
+ *[
956
+ populate(starts, counts)
957
+ for starts, counts in self._get_chunk_sizes(buffer_shape, dtype, offset)
958
+ ]
959
+ )
1025
960
 
1026
- async def _put_array_chuncked(self, uid: DataArrayIdentifier, data: np.ndarray, transaction_id: Uuid | None = None):
1027
- transport_array_type = utils_arrays.get_transport(data.dtype)
961
+ return buffer
1028
962
 
1029
- await self._put_uninitialized_data_array(uid, data.shape, transport_array_type=transport_array_type)
1030
- if isinstance(transaction_id, Uuid):
1031
- await self.commit_transaction(transaction_id)
963
async def _put_array_chuncked(self, *args, **kwargs):
    """Deprecated, misspelled alias of ``_put_array_chunked``.

    Emits a ``DeprecationWarning`` and forwards all positional and keyword
    arguments unchanged.

    Returns:
        Whatever ``_put_array_chunked`` returns for the same arguments.
    """
    warnings.warn(
        "This function is deprecated and will be removed in a later version of "
        "pyetp. Please use the updated function 'pyetp._put_array_chunked'.",
        DeprecationWarning,
        stacklevel=2,
    )
    # The delegate is a coroutine function, so it must be awaited here.
    # Returning the bare call would make ``await self._put_array_chuncked(...)``
    # evaluate to an un-awaited coroutine object, and the upload would never run.
    return await self._put_array_chunked(*args, **kwargs)
1032
971
 
1033
- ds_uri = DataspaceURI.from_any(uid.uri)
1034
- t_id = None
1035
- if isinstance(transaction_id, Uuid):
1036
- t_id = await self.start_transaction(ds_uri, False)
972
+ async def _put_array_chunked(self, uid: DataArrayIdentifier, data: np.ndarray):
1037
973
  for starts, counts in self._get_chunk_sizes(data.shape, data.dtype):
1038
974
  await self.put_subarray(uid, data, starts, counts)
1039
- if isinstance(t_id, Uuid):
1040
- await self.commit_transaction(t_id)
1041
-
1042
- return {uid.uri: ''}
1043
-
1044
- async def _put_uninitialized_data_array(self, uid: DataArrayIdentifier, shape: T.Tuple[int, ...], transport_array_type=AnyArrayType.ARRAY_OF_FLOAT, logical_array_type=AnyLogicalArrayType.ARRAY_OF_BOOLEAN):
1045
975
 
976
+ return {uid.uri: ""}
1046
977
 
978
async def _put_uninitialized_data_array(
    self,
    uid: DataArrayIdentifier,
    shape: T.Tuple[int, ...],
    transport_array_type: AnyArrayType,
    logical_array_type: AnyLogicalArrayType,
):
    """Send a ``PutUninitializedDataArrays`` request for the array ``uid``.

    The metadata carries ``shape`` as its dimensions, the given transport and
    logical array types, and ``self.timestamp`` for both the created and
    last-write fields.

    Returns:
        The ``success`` member of the ``PutUninitializedDataArraysResponse``.
    """
    array_metadata = DataArrayMetadata(
        dimensions=list(shape),  # type: ignore
        transportArrayType=transport_array_type,
        logicalArrayType=logical_array_type,
        storeLastWrite=self.timestamp,
        storeCreated=self.timestamp,
    )
    entry = PutUninitializedDataArrayType(uid=uid, metadata=array_metadata)
    request = PutUninitializedDataArrays(dataArrays={uid.path_in_resource: entry})

    response = await self.send(request)
    assert isinstance(response, PutUninitializedDataArraysResponse), (
        "Expected PutUninitializedDataArraysResponse"
    )
    assert len(response.success) == 1, "expected one success"
    return response.success
1063
1005
 
1064
1006
 
1065
1007
  # define an asynchronous context manager
1066
1008
  class connect:
1067
-
1068
1009
  def __init__(self, authorization: T.Optional[SecretStr] = None):
1069
1010
  self.server_url = SETTINGS.etp_url
1070
1011
  self.authorization = authorization
@@ -1079,7 +1020,6 @@ class connect:
1079
1020
  # async with connect(...) as ...:
1080
1021
 
1081
1022
  async def __aenter__(self):
1082
-
1083
1023
  headers = {}
1084
1024
  if isinstance(self.authorization, str):
1085
1025
  headers["Authorization"] = self.authorization
@@ -1088,16 +1028,16 @@ class connect:
1088
1028
  if self.data_partition is not None:
1089
1029
  headers["data-partition-id"] = self.data_partition
1090
1030
 
1091
- ws = await websockets.connect(
1031
+ self.ws = await websockets.connect(
1092
1032
  self.server_url,
1093
1033
  subprotocols=[ETPClient.SUB_PROTOCOL], # type: ignore
1094
- extra_headers=headers,
1034
+ additional_headers=headers,
1095
1035
  max_size=SETTINGS.MaxWebSocketMessagePayloadSize,
1096
1036
  ping_timeout=self.timeout,
1097
1037
  open_timeout=None,
1098
1038
  )
1099
1039
 
1100
- self.client = ETPClient(ws, timeout=self.timeout)
1040
+ self.client = ETPClient(self.ws, timeout=self.timeout)
1101
1041
 
1102
1042
  try:
1103
1043
  await self.client.request_session()
@@ -1111,3 +1051,4 @@ class connect:
1111
1051
  # exit the async context manager
1112
1052
async def __aexit__(self, exc_type, exc: Exception, tb: TracebackType):
    """Close the ETP client, then the websocket, when leaving the context.

    The websocket is closed in a ``finally`` clause so that an exception
    raised by ``self.client.close()`` cannot leave the connection open; that
    exception still propagates to the caller.

    Returns:
        None, so an exception raised inside the ``async with`` body is not
        suppressed.
    """
    try:
        await self.client.close()
    finally:
        await self.ws.close()