pyetp-0.0.43-py3-none-any.whl → pyetp-0.0.44-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyetp/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from ._version import __version__
2
- from .client import ETPClient, ETPError, connect
2
+ from .client import ETPClient, ETPError, connect, etp_connect, etp_persistent_connect
3
3
  from .uri import DataObjectURI, DataspaceURI
4
4
 
5
5
  __all__ = [
@@ -7,6 +7,8 @@ __all__ = [
7
7
  "ETPClient",
8
8
  "ETPError",
9
9
  "connect",
10
+ "etp_connect",
11
+ "etp_persistent_connect",
10
12
  "DataObjectURI",
11
13
  "DataspaceURI",
12
14
  ]
pyetp/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.0.43'
32
- __version_tuple__ = version_tuple = (0, 0, 43)
31
+ __version__ = version = '0.0.44'
32
+ __version_tuple__ = version_tuple = (0, 0, 44)
33
33
 
34
34
  __commit_id__ = commit_id = None
pyetp/client.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import asyncio
2
+ import contextlib
2
3
  import datetime
3
4
  import logging
4
5
  import sys
@@ -6,10 +7,13 @@ import typing as T
6
7
  import uuid
7
8
  import warnings
8
9
  from collections import defaultdict
10
+ from collections.abc import AsyncGenerator
9
11
  from types import TracebackType
10
12
 
11
13
  import numpy as np
14
+ import numpy.typing as npt
12
15
  import websockets
16
+ import websockets.client
13
17
  from etpproto.connection import CommunicationProtocol, ConnectionType, ETPConnection
14
18
  from etpproto.messages import Message, MessageFlags
15
19
  from etptypes import ETPModel
@@ -143,6 +147,7 @@ from xtgeo import RegularSurface
143
147
 
144
148
  import resqml_objects.v201 as ro
145
149
  from pyetp import utils_arrays, utils_xml
150
+ from pyetp._version import version
146
151
  from pyetp.config import SETTINGS
147
152
  from pyetp.uri import DataObjectURI, DataspaceURI
148
153
  from resqml_objects import parse_resqml_v201_object, serialize_resqml_v201_object
@@ -175,6 +180,14 @@ except ImportError:
175
180
  except asyncio.CancelledError as e:
176
181
  raise asyncio.TimeoutError(f"Timeout ({delay}s)") from e
177
182
 
183
+ TimeoutError = asyncio.TimeoutError
184
+
185
+ try:
186
+ # Python >= 3.11
187
+ from typing import Self
188
+ except ImportError:
189
+ Self = "ETPClient"
190
+
178
191
 
179
192
  class ETPError(Exception):
180
193
  def __init__(self, message: str, code: int):
@@ -192,6 +205,10 @@ class ETPError(Exception):
192
205
  return list(map(cls.from_proto, errors))
193
206
 
194
207
 
208
+ class ReceiveWorkerExited(Exception):
209
+ pass
210
+
211
+
195
212
  def get_all_etp_protocol_classes():
196
213
  """Update protocol - all exception protocols are now per message"""
197
214
 
@@ -211,17 +228,31 @@ class ETPClient(ETPConnection):
211
228
  _recv_events: T.Dict[int, asyncio.Event]
212
229
  _recv_buffer: T.Dict[int, T.List[ETPModel]]
213
230
 
214
- def __init__(self, ws: websockets.ClientConnection, timeout=10.0):
231
+ def __init__(
232
+ self,
233
+ ws: websockets.ClientConnection,
234
+ etp_timeout: float | None = 10.0,
235
+ max_message_size: float = 2**20,
236
+ application_name: str = "pyetp",
237
+ application_version: str = version,
238
+ ) -> None:
215
239
  super().__init__(connection_type=ConnectionType.CLIENT)
240
+
241
+ self.application_name = application_name
242
+ self.application_version = application_version
243
+
216
244
  self._recv_events = {}
217
245
  self._recv_buffer = defaultdict(lambda: list()) # type: ignore
218
246
  self.ws = ws
219
247
 
220
- self.timeout = timeout
248
+ # Ensure a minimum timeout of 10 seconds; `None` disables the timeout entirely.
249
+ self.etp_timeout = (
250
+ etp_timeout if etp_timeout is None or etp_timeout > 10.0 else 10.0
251
+ )
221
252
  self.client_info.endpoint_capabilities["MaxWebSocketMessagePayloadSize"] = (
222
- SETTINGS.MaxWebSocketMessagePayloadSize
253
+ max_message_size
223
254
  )
224
- self.__recvtask = asyncio.create_task(self.__recv__())
255
+ self.__recvtask = asyncio.create_task(self.__recv())
225
256
 
226
257
  #
227
258
  # client
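
The reworked constructor above replaces the SETTINGS-driven defaults (application name, version, maximum message size) with explicit keyword arguments. A minimal sketch of constructing `ETPClient` directly over an already-open websocket; the endpoint URL and argument values are placeholders, not values from this release:

```python
import websockets
from pyetp import ETPClient

async def open_client() -> ETPClient:
    # "wss://example.com/etp" is a placeholder endpoint.
    ws = await websockets.connect(
        "wss://example.com/etp",
        subprotocols=["etp12.energistics.org"],
        max_size=2**23,
    )
    # The new keyword arguments replace the previous SETTINGS-based defaults.
    client = ETPClient(
        ws,
        etp_timeout=30.0,         # values below 10 s are clamped to 10 s; None disables the timeout
        max_message_size=2**23,   # advertised as MaxWebSocketMessagePayloadSize
        application_name="my-app",
        application_version="1.0.0",
    )
    await client.request_session()
    return client
```
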
@@ -242,25 +273,81 @@ class ETPClient(ETPConnection):
242
273
  # create future recv event
243
274
  self._recv_events[msg.header.message_id] = asyncio.Event()
244
275
 
276
+ tasks = []
245
277
  for msg_part in msg.encode_message_generator(self.max_size, self):
246
- await self.ws.send(msg_part)
278
+ tasks.append(self.ws.send(msg_part))
279
+
280
+ await asyncio.gather(*tasks)
247
281
 
248
282
  return msg.header.message_id
249
283
 
250
284
  async def _recv(self, correlation_id: int) -> ETPModel:
251
285
  assert correlation_id in self._recv_events, (
252
- "trying to recv response on non-existing message"
286
+ "Trying to receive a response on non-existing message"
253
287
  )
254
288
 
255
- async with timeout(self.timeout):
256
- await self._recv_events[correlation_id].wait()
289
+ def timeout_intervals(etp_timeout):
290
+ # Local function generating progressively longer timeout intervals.
291
+
292
+ # Use the timeout-interval generator from the Python websockets
293
+ # library.
294
+ backoff_generator = websockets.client.backoff(
295
+ initial_delay=5.0, min_delay=5.0, max_delay=20.0
296
+ )
297
+
298
+ # Check if we should never time out.
299
+ if etp_timeout is None:
300
+ # This is an infinite generator, so it should never exit.
301
+ yield from backoff_generator
302
+ return
303
+
304
+ # Generate timeout intervals until we have reached the
305
+ # `etp_timeout`-threshold.
306
+ csum = 0.0
307
+ for d in backoff_generator:
308
+ yield d
309
+
310
+ csum += d
311
+
312
+ if csum >= etp_timeout:
313
+ break
314
+
315
+ for ti in timeout_intervals(self.etp_timeout):
316
+ try:
317
+ # Wait for an event for `ti` seconds.
318
+ async with timeout(ti):
319
+ await self._recv_events[correlation_id].wait()
320
+ except TimeoutError:
321
+ # Check if the receiver task is still running.
322
+ if self.__recvtask.done():
323
+ # Raise any errors by waiting for the task to finish.
324
+ await self.__recvtask
325
+
326
+ logger.error(
327
+ "Receiver task terminated without errors. This should not happen"
328
+ )
329
+
330
+ raise ReceiveWorkerExited
331
+ else:
332
+ # Break out of for-loop, and start processing message.
333
+ break
334
+ else:
335
+ # The for-loop finished without breaking. In other words, we have
336
+ # timed out.
337
+ assert self.etp_timeout is not None
338
+ raise TimeoutError(
339
+ f"Receiver task did not set event within {self.etp_timeout} seconds"
340
+ )
257
341
 
258
- # cleanup
259
- bodies = self._clear_msg_on_buffer(correlation_id)
342
+ # Remove the event from the dict of pending events
343
+ del self._recv_events[correlation_id]
344
+ # Read message bodies from buffer.
345
+ bodies = self._recv_buffer.pop(correlation_id)
260
346
 
261
- # error handling
347
+ # Check if there are errors in the received messages.
262
348
  errors = self._parse_error_info(bodies)
263
349
 
350
+ # Raise errors in case there are any.
264
351
  if len(errors) == 1:
265
352
  raise ETPError.from_proto(errors.pop())
266
353
  elif len(errors) > 1:
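
The `_recv` rewrite above replaces the single fixed timeout with a series of shorter waits whose lengths come from the websockets backoff generator, stopping once their cumulative duration reaches `etp_timeout` (or never, when it is `None`). A standalone sketch of that capping pattern, using a plain delay iterable as a stand-in for `websockets.client.backoff()`:

```python
import itertools
import typing as T

def capped_intervals(
    delays: T.Iterable[float], budget: float | None
) -> T.Iterator[float]:
    """Yield waiting intervals until their cumulative sum reaches the budget."""
    total = 0.0
    for delay in delays:
        yield delay
        total += delay
        if budget is not None and total >= budget:
            break

# 5-second intervals capped at a 12-second budget -> [5.0, 5.0, 5.0], then stop.
print(list(capped_intervals(itertools.repeat(5.0), 12.0)))
```
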
@@ -285,9 +372,14 @@ class ETPClient(ETPConnection):
285
372
  errors.extend(body.errors.values())
286
373
  return errors
287
374
 
375
+ async def __aexit__(self, *exc_details) -> None:
376
+ await self.close(reason="Client exiting")
377
+
288
378
  async def close(self, reason=""):
379
+ close_session_sent = False
289
380
  try:
290
381
  await self._send(CloseSession(reason=reason))
382
+ close_session_sent = True
291
383
  except websockets.ConnectionClosed:
292
384
  logger.error(
293
385
  "Websockets connection is closed, unable to send a CloseSession-message"
@@ -324,12 +416,23 @@ class ETPClient(ETPConnection):
324
416
  # Reading them will speed up the closing of the connection.
325
417
  counter = 0
326
418
  try:
327
- async for msg in self.ws:
328
- counter += 1
419
+ # In some cases the server does not drop the connection after we
420
+ # have sent the `CloseSession`-message. We therefore add a timeout
421
+ # to the reading of possibly lost messages.
422
+ async with timeout(self.etp_timeout or 10):
423
+ async for msg in self.ws:
424
+ counter += 1
329
425
  except websockets.ConnectionClosed:
330
- # The websockets connection had already closed. Either successfully
426
+ # The websockets connection has already closed. Either successfully
331
427
  # or with an error, but we ignore both cases.
332
428
  pass
429
+ except TimeoutError:
430
+ if close_session_sent:
431
+ logger.error(
432
+ "Websockets connection was not closed within "
433
+ f"{self.etp_timeout or 10} seconds after the "
434
+ "`CloseSession`-message was sent"
435
+ )
333
436
 
334
437
  if counter > 0:
335
438
  logger.error(
@@ -339,26 +442,16 @@ class ETPClient(ETPConnection):
339
442
 
340
443
  logger.debug("Client closed")
341
444
 
342
- #
343
- #
344
- #
345
-
346
- def _clear_msg_on_buffer(self, correlation_id: int):
347
- del self._recv_events[correlation_id]
348
- return self._recv_buffer.pop(correlation_id)
349
-
350
- def _add_msg_to_buffer(self, msg: Message):
351
- self._recv_buffer[msg.header.correlation_id].append(msg.body)
352
-
353
- # NOTE: should we add task to autoclear buffer message if never waited on ?
354
- if msg.is_final_msg():
355
- # set response on send event
356
- self._recv_events[msg.header.correlation_id].set()
357
-
358
- async def __recv__(self):
359
- logger.debug("starting recv loop")
445
+ async def __recv(self):
446
+ logger.debug("Starting receiver loop")
360
447
 
361
448
  while True:
449
+ # We use this way of receiving messages, instead of the `async
450
+ # for`-pattern, in order to raise all
451
+ # `websockets.exceptions.ConnectionClosed`-errors (including the
452
+ # `websockets.exceptions.ConnectionClosedOK` error). In the `async
453
+ # for`-case a closing code of `1000` (normal closing) just exits
454
+ # the loop.
362
455
  msg_data = await self.ws.recv()
363
456
  msg = Message.decode_binary_message(
364
457
  T.cast(bytes, msg_data), ETPClient.generic_transition_table
@@ -369,12 +462,19 @@ class ETPClient(ETPConnection):
369
462
  continue
370
463
 
371
464
  logger.debug(f"recv {msg.body.__class__.__name__} {repr(msg.header)}")
372
- self._add_msg_to_buffer(msg)
465
+ self._recv_buffer[msg.header.correlation_id].append(msg.body)
466
+
467
+ if msg.is_final_msg():
468
+ # set response on send event
469
+ self._recv_events[msg.header.correlation_id].set()
373
470
 
374
471
  #
375
472
  # session related
376
473
  #
377
474
 
475
+ async def __aenter__(self) -> Self:
476
+ return await self.request_session()
477
+
378
478
  async def request_session(self):
379
479
  # Handshake protocol
380
480
  etp_version = Version(major=1, minor=2, revision=0, patch=0)
@@ -390,8 +490,8 @@ class ETPClient(ETPConnection):
390
490
 
391
491
  msg = await self.send(
392
492
  RequestSession(
393
- applicationName=SETTINGS.application_name,
394
- applicationVersion=SETTINGS.application_version,
493
+ applicationName=self.application_name,
494
+ applicationVersion=self.application_version,
395
495
  clientInstanceId=uuid.uuid4(), # type: ignore
396
496
  requestedProtocols=[
397
497
  SupportedProtocol(
@@ -435,16 +535,23 @@ class ETPClient(ETPConnection):
435
535
 
436
536
  return msg
437
537
 
438
- #
538
+ @staticmethod
539
+ def assert_response(response: ETPModel, expected_type: T.Type[ETPModel]) -> None:
540
+ assert isinstance(response, expected_type), (
541
+ f"Expected {expected_type}, got {type(response)} with content {response}"
542
+ )
439
543
 
440
544
  @property
441
545
  def max_size(self):
442
- return SETTINGS.MaxWebSocketMessagePayloadSize
443
- # return self.client_info.getCapability("MaxWebSocketMessagePayloadSize")
546
+ return self.client_info.getCapability("MaxWebSocketMessagePayloadSize")
444
547
 
445
548
  @property
446
549
  def max_array_size(self):
447
- return self.max_size - 512 # maxsize - 512 bytes for header and body
550
+ if self.max_size <= 3000:
551
+ raise AttributeError(
552
+ "The maximum size of a websocket message must be greater than 3000"
553
+ )
554
+ return self.max_size - 3000 # maxsize - 3000 bytes for header and body
448
555
 
449
556
  @property
450
557
  def timestamp(self):
@@ -455,12 +562,12 @@ class ETPClient(ETPConnection):
455
562
  raise Exception("Max one / in dataspace name")
456
563
  return DataspaceURI.from_name(ds)
457
564
 
458
- def list_objects(self, dataspace_uri: DataspaceURI, depth: int = 1) -> list:
565
+ def list_objects(self, dataspace_uri: DataspaceURI | str, depth: int = 1) -> list:
459
566
  return self.send(
460
567
  GetResources(
461
568
  scope=ContextScopeKind.TARGETS_OR_SELF,
462
569
  context=ContextInfo(
463
- uri=dataspace_uri.raw_uri,
570
+ uri=str(dataspace_uri),
464
571
  depth=depth,
465
572
  dataObjectTypes=[],
466
573
  navigableEdges=RelationshipKind.PRIMARY,
@@ -549,33 +656,67 @@ class ETPClient(ETPConnection):
549
656
  )
550
657
  return response.success
551
658
 
552
- #
553
- # data objects
554
- #
555
-
556
659
  async def get_data_objects(self, *uris: T.Union[DataObjectURI, str]):
557
- _uris = list(map(str, uris))
660
+ tasks = []
661
+ for uri in uris:
662
+ task = self.send(GetDataObjects(uris={str(uri): str(uri)}))
663
+ tasks.append(task)
558
664
 
559
- msg = await self.send(GetDataObjects(uris=dict(zip(_uris, _uris))))
560
- assert isinstance(msg, GetDataObjectsResponse), "Expected dataobjectsresponse"
561
- assert len(msg.data_objects) == len(_uris), (
562
- "Here we assume that all three objects fit in a single record"
563
- )
665
+ responses = await asyncio.gather(*tasks)
666
+ assert len(responses) == len(uris)
667
+
668
+ data_objects = []
669
+ errors = []
670
+ for uri, response in zip(uris, responses):
671
+ if not isinstance(response, GetDataObjectsResponse):
672
+ errors.append(
673
+ TypeError(
674
+ "Expected GetDataObjectsResponse, got "
675
+ f"{response.__class__.__name} with content: {response}",
676
+ )
677
+ )
678
+ data_objects.append(response.data_objects[str(uri)])
679
+
680
+ if len(errors) > 0:
681
+ raise ExceptionGroup(
682
+ f"There were {len(errors)} errors in ETPClient.get_data_objects",
683
+ errors,
684
+ )
564
685
 
565
- return [msg.data_objects[u] for u in _uris]
686
+ return data_objects
566
687
 
567
688
  async def put_data_objects(self, *objs: DataObject):
568
- response = await self.send(
569
- PutDataObjects(
570
- data_objects={f"{p.resource.name} - {p.resource.uri}": p for p in objs},
689
+ tasks = []
690
+ for obj in objs:
691
+ task = self.send(
692
+ PutDataObjects(
693
+ data_objects={f"{obj.resource.name} - {obj.resource.uri}": obj},
694
+ ),
571
695
  )
572
- )
696
+ tasks.append(task)
573
697
 
574
- assert isinstance(response, PutDataObjectsResponse), (
575
- "Expected PutDataObjectsResponse"
576
- )
698
+ responses = await asyncio.gather(*tasks)
577
699
 
578
- return response.success
700
+ errors = []
701
+ for response in responses:
702
+ if not isinstance(response, PutDataObjectsResponse):
703
+ errors.append(
704
+ TypeError(
705
+ "Expected PutDataObjectsResponse, got "
706
+ f"{response.__class__.__name} with content: {response}",
707
+ )
708
+ )
709
+ if len(errors) > 0:
710
+ raise ExceptionGroup(
711
+ f"There were {len(errors)} errors in ETPClient.put_data_objects",
712
+ errors,
713
+ )
714
+
715
+ successes = {}
716
+ for response in responses:
717
+ successes = {**successes, **response.success}
718
+
719
+ return successes
579
720
 
580
721
  async def get_resqml_objects(
581
722
  self, *uris: T.Union[DataObjectURI, str]
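
The reworked `get_data_objects`/`put_data_objects` above issue one request per object, gather the responses concurrently, and aggregate any type mismatches into an `ExceptionGroup` (a built-in from Python 3.11 onwards). A minimal standalone sketch of that fan-out-and-aggregate pattern; `send_one` is a hypothetical stand-in for `ETPClient.send`:

```python
import asyncio
import typing as T

async def fan_out(
    send_one: T.Callable[[str], T.Awaitable[object]],
    uris: T.Sequence[str],
    expected_type: type,
) -> list[object]:
    # Issue one request per URI and wait for all responses concurrently.
    responses = await asyncio.gather(*(send_one(uri) for uri in uris))
    # Collect every response of an unexpected type instead of failing on the first.
    errors = [
        TypeError(f"Expected {expected_type.__name__}, got {type(r).__name__}")
        for r in responses
        if not isinstance(r, expected_type)
    ]
    if errors:
        raise ExceptionGroup(f"{len(errors)} request(s) returned an unexpected type", errors)
    return list(responses)
```
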
@@ -766,7 +907,12 @@ class ETPClient(ETPConnection):
766
907
  async def get_array(self, uid: DataArrayIdentifier):
767
908
  # Check if we can download the full array in one go.
768
909
  (meta,) = await self.get_array_metadata(uid)
769
- if utils_arrays.get_transport_array_size(meta) > self.max_array_size:
910
+ if (
911
+ utils_arrays.get_transport_array_size(
912
+ meta.transport_array_type, meta.dimensions
913
+ )
914
+ > self.max_array_size
915
+ ):
770
916
  return await self._get_array_chunked(uid)
771
917
 
772
918
  response = await self.send(
@@ -779,6 +925,250 @@ class ETPClient(ETPConnection):
779
925
  arrays = list(response.data_arrays.values())
780
926
  return utils_arrays.get_numpy_array_from_etp_data_array(arrays[0])
781
927
 
928
+ async def download_array(
929
+ self,
930
+ epc_uri: str | DataObjectURI,
931
+ path_in_resource: str,
932
+ ) -> npt.NDArray[utils_arrays.LogicalArrayDTypes]:
933
+ # Create identifier for the data.
934
+ dai = DataArrayIdentifier(
935
+ uri=str(epc_uri),
936
+ path_in_resource=path_in_resource,
937
+ )
938
+
939
+ response = await self.send(
940
+ GetDataArrayMetadata(data_arrays={dai.path_in_resource: dai}),
941
+ )
942
+
943
+ self.assert_response(response, GetDataArrayMetadataResponse)
944
+ assert (
945
+ len(response.array_metadata) == 1
946
+ and dai.path_in_resource in response.array_metadata
947
+ )
948
+
949
+ metadata = response.array_metadata[dai.path_in_resource]
950
+
951
+ # Check if we can download the full array in a single message.
952
+ if (
953
+ utils_arrays.get_transport_array_size(
954
+ metadata.transport_array_type, metadata.dimensions
955
+ )
956
+ >= self.max_array_size
957
+ ):
958
+ transport_dtype = utils_arrays.get_dtype_from_any_array_type(
959
+ metadata.transport_array_type,
960
+ )
961
+ # NOTE: The logical array type is not yet supported by the
962
+ # open-etp-server. As such, the transport array type will be the actual
963
+ # array type used. We only add this call to prepare for when it
964
+ # will be used.
965
+ logical_dtype = utils_arrays.get_dtype_from_any_logical_array_type(
966
+ metadata.logical_array_type,
967
+ )
968
+ if logical_dtype != np.dtype(np.bool_):
969
+ # If this debug message is triggered we should test the
970
+ # mapping.
971
+ logger.debug(
972
+ "Logical array type has changed: "
973
+ f"{metadata.logical_array_type = }, with {logical_dtype = }"
974
+ )
975
+
976
+ # Create a buffer for the data.
977
+ data = np.zeros(metadata.dimensions, dtype=transport_dtype)
978
+
979
+ # Get list with starting indices in each block, and a list with the
980
+ # number of elements along each axis for each block.
981
+ block_starts, block_counts = utils_arrays.get_array_block_sizes(
982
+ data.shape, data.dtype, self.max_array_size
983
+ )
984
+
985
+ def data_subarrays_key(pir: str, i: int) -> str:
986
+ return pir + f" ({i})"
987
+
988
+ tasks = []
989
+ for i, (starts, counts) in enumerate(zip(block_starts, block_counts)):
990
+ task = self.send(
991
+ GetDataSubarrays(
992
+ data_subarrays={
993
+ data_subarrays_key(
994
+ dai.path_in_resource, i
995
+ ): GetDataSubarraysType(
996
+ uid=dai,
997
+ starts=starts,
998
+ counts=counts,
999
+ ),
1000
+ },
1001
+ ),
1002
+ )
1003
+ tasks.append(task)
1004
+
1005
+ responses = await asyncio.gather(*tasks)
1006
+
1007
+ data_blocks = []
1008
+ for i, response in enumerate(responses):
1009
+ self.assert_response(response, GetDataSubarraysResponse)
1010
+ assert (
1011
+ len(response.data_subarrays) == 1
1012
+ and data_subarrays_key(dai.path_in_resource, i)
1013
+ in response.data_subarrays
1014
+ )
1015
+
1016
+ data_block = utils_arrays.get_numpy_array_from_etp_data_array(
1017
+ response.data_subarrays[
1018
+ data_subarrays_key(dai.path_in_resource, i)
1019
+ ],
1020
+ )
1021
+ data_blocks.append(data_block)
1022
+
1023
+ for data_block, starts, counts in zip(
1024
+ data_blocks, block_starts, block_counts
1025
+ ):
1026
+ # Create slice-objects for each block.
1027
+ slices = tuple(
1028
+ map(
1029
+ lambda s, c: slice(s, s + c),
1030
+ np.array(starts).astype(int),
1031
+ np.array(counts).astype(int),
1032
+ )
1033
+ )
1034
+ data[slices] = data_block
1035
+
1036
+ # Return after fetching all sub arrays.
1037
+ return data
1038
+
1039
+ # Download the full array in one go.
1040
+ response = await self.send(
1041
+ GetDataArrays(data_arrays={dai.path_in_resource: dai}),
1042
+ )
1043
+
1044
+ self.assert_response(response, GetDataArraysResponse)
1045
+ assert (
1046
+ len(response.data_arrays) == 1
1047
+ and dai.path_in_resource in response.data_arrays
1048
+ )
1049
+
1050
+ return utils_arrays.get_numpy_array_from_etp_data_array(
1051
+ response.data_arrays[dai.path_in_resource]
1052
+ )
1053
+
1054
+ async def upload_array(
1055
+ self,
1056
+ epc_uri: str | DataObjectURI,
1057
+ path_in_resource: str,
1058
+ data: npt.NDArray[utils_arrays.LogicalArrayDTypes],
1059
+ ) -> None:
1060
+ # Fetch ETP logical and transport array types
1061
+ logical_array_type, transport_array_type = (
1062
+ utils_arrays.get_logical_and_transport_array_types(data.dtype)
1063
+ )
1064
+
1065
+ # Create identifier for the data.
1066
+ dai = DataArrayIdentifier(
1067
+ uri=str(epc_uri),
1068
+ path_in_resource=path_in_resource,
1069
+ )
1070
+
1071
+ # Get current time as a UTC-timestamp.
1072
+ now = self.timestamp
1073
+
1074
+ # Allocate space on server for the array.
1075
+ response = await self.send(
1076
+ PutUninitializedDataArrays(
1077
+ data_arrays={
1078
+ dai.path_in_resource: PutUninitializedDataArrayType(
1079
+ uid=dai,
1080
+ metadata=DataArrayMetadata(
1081
+ dimensions=list(data.shape),
1082
+ transport_array_type=transport_array_type,
1083
+ logical_array_type=logical_array_type,
1084
+ store_last_write=now,
1085
+ store_created=now,
1086
+ ),
1087
+ ),
1088
+ },
1089
+ ),
1090
+ )
1091
+
1092
+ self.assert_response(response, PutUninitializedDataArraysResponse)
1093
+ assert len(response.success) == 1 and dai.path_in_resource in response.success
1094
+
1095
+ # Check if we can upload the entire array in one go, or if we need to
1096
+ # upload it in smaller blocks.
1097
+ if data.nbytes > self.max_array_size:
1098
+ tasks = []
1099
+
1100
+ # Get list with starting indices in each block, and a list with the
1101
+ # number of elements along each axis for each block.
1102
+ block_starts, block_counts = utils_arrays.get_array_block_sizes(
1103
+ data.shape, data.dtype, self.max_array_size
1104
+ )
1105
+
1106
+ for starts, counts in zip(block_starts, block_counts):
1107
+ # Create slice-objects for each block.
1108
+ slices = tuple(
1109
+ map(
1110
+ lambda s, c: slice(s, s + c),
1111
+ np.array(starts).astype(int),
1112
+ np.array(counts).astype(int),
1113
+ )
1114
+ )
1115
+
1116
+ # Slice the array, and convert to the relevant ETP-array type.
1117
+ # Note in particular the extra `.data` after the call. The
1118
+ # data should not be of type `DataArray`, but `AnyArray`, so we
1119
+ # need to fetch it from the `DataArray`.
1120
+ etp_subarray_data = utils_arrays.get_etp_data_array_from_numpy(
1121
+ data[slices]
1122
+ ).data
1123
+
1124
+ # Create an asynchronous task to upload a block to the
1125
+ # ETP-server.
1126
+ task = self.send(
1127
+ PutDataSubarrays(
1128
+ data_subarrays={
1129
+ dai.path_in_resource: PutDataSubarraysType(
1130
+ uid=dai,
1131
+ data=etp_subarray_data,
1132
+ starts=starts,
1133
+ counts=counts,
1134
+ ),
1135
+ },
1136
+ ),
1137
+ )
1138
+ tasks.append(task)
1139
+
1140
+ # Upload all blocks.
1141
+ responses = await asyncio.gather(*tasks)
1142
+
1143
+ # Check for successful responses.
1144
+ for response in responses:
1145
+ self.assert_response(response, PutDataSubarraysResponse)
1146
+ assert (
1147
+ len(response.success) == 1
1148
+ and dai.path_in_resource in response.success
1149
+ )
1150
+
1151
+ # Return after uploading all sub arrays.
1152
+ return
1153
+
1154
+ # Convert NumPy data-array to an ETP-transport array.
1155
+ etp_array_data = utils_arrays.get_etp_data_array_from_numpy(data)
1156
+
1157
+ # Pass entire array in one message.
1158
+ response = await self.send(
1159
+ PutDataArrays(
1160
+ data_arrays={
1161
+ dai.path_in_resource: PutDataArraysType(
1162
+ uid=dai,
1163
+ array=etp_array_data,
1164
+ ),
1165
+ }
1166
+ )
1167
+ )
1168
+
1169
+ self.assert_response(response, PutDataArraysResponse)
1170
+ assert len(response.success) == 1 and dai.path_in_resource in response.success
1171
+
782
1172
  async def put_array(
783
1173
  self,
784
1174
  uid: DataArrayIdentifier,
@@ -858,7 +1248,7 @@ class ETPClient(ETPConnection):
858
1248
  counts = np.array(counts).astype(np.int64) # len = 2
859
1249
  ends = starts + counts # len = 2
860
1250
 
861
- slices = tuple(map(lambda se: slice(se[0], se[1]), zip(starts, ends)))
1251
+ slices = tuple(map(lambda s, e: slice(s, e), starts, ends))
862
1252
  dataarray = utils_arrays.get_etp_data_array_from_numpy(data[slices])
863
1253
  payload = PutDataSubarraysType(
864
1254
  uid=uid,
@@ -881,17 +1271,19 @@ class ETPClient(ETPConnection):
881
1271
  assert len(response.success) == 1, "expected one success"
882
1272
  return response.success
883
1273
 
884
- #
885
- # chunked get array - ETP will not chunk response - so we need to do it manually
886
- #
887
-
888
1274
  def _get_chunk_sizes(
889
1275
  self, shape, dtype: np.dtype[T.Any] = np.dtype(np.float32), offset=0
890
1276
  ):
1277
+ warnings.warn(
1278
+ "This function is deprecated and will be removed in a later version of "
1279
+ "pyetp. The replacement is located via the import "
1280
+ "`from pyetp.utils_arrays import get_array_block_sizes`.",
1281
+ DeprecationWarning,
1282
+ stacklevel=2,
1283
+ )
891
1284
  shape = np.array(shape)
892
1285
 
893
1286
  # capsize blocksize
894
- # remove 512 bytes for headers and body
895
1287
  max_items = self.max_array_size / dtype.itemsize
896
1288
  block_size = np.power(max_items, 1.0 / len(shape))
897
1289
  block_size = min(2048, int(block_size // 2) * 2)
@@ -1037,7 +1429,13 @@ class connect:
1037
1429
  open_timeout=None,
1038
1430
  )
1039
1431
 
1040
- self.client = ETPClient(self.ws, timeout=self.timeout)
1432
+ self.client = ETPClient(
1433
+ self.ws,
1434
+ etp_timeout=self.timeout,
1435
+ max_message_size=SETTINGS.MaxWebSocketMessagePayloadSize,
1436
+ application_name=SETTINGS.application_name,
1437
+ application_version=SETTINGS.application_version,
1438
+ )
1041
1439
 
1042
1440
  try:
1043
1441
  await self.client.request_session()
@@ -1052,3 +1450,72 @@ class connect:
1052
1450
  async def __aexit__(self, exc_type, exc: Exception, tb: TracebackType):
1053
1451
  await self.client.close()
1054
1452
  await self.ws.close()
1453
+
1454
+
1455
+ @contextlib.asynccontextmanager
1456
+ async def etp_connect(
1457
+ uri: str,
1458
+ data_partition_id: str | None = None,
1459
+ authorization: str | None = None,
1460
+ etp_timeout: float = 10.0,
1461
+ max_message_size: float = 2**20,
1462
+ ) -> ETPClient:
1463
+ additional_headers = {}
1464
+
1465
+ if authorization is not None:
1466
+ additional_headers["Authorization"] = authorization
1467
+ if data_partition_id is not None:
1468
+ additional_headers["data-partition-id"] = data_partition_id
1469
+
1470
+ subprotocols = ["etp12.energistics.org"]
1471
+
1472
+ async with (
1473
+ websockets.connect(
1474
+ uri=uri,
1475
+ subprotocols=subprotocols,
1476
+ max_size=max_message_size,
1477
+ additional_headers=additional_headers,
1478
+ ) as ws,
1479
+ ETPClient(
1480
+ ws=ws,
1481
+ etp_timeout=etp_timeout,
1482
+ max_message_size=max_message_size,
1483
+ ) as etp_client,
1484
+ ):
1485
+ yield etp_client
1486
+
1487
+
1488
+ async def etp_persistent_connect(
1489
+ uri: str,
1490
+ data_partition_id: str | None = None,
1491
+ authorization: str | None = None,
1492
+ etp_timeout: float = 10.0,
1493
+ max_message_size: float = 2**20,
1494
+ ) -> AsyncGenerator[ETPClient]:
1495
+ additional_headers = {}
1496
+
1497
+ if authorization is not None:
1498
+ additional_headers["Authorization"] = authorization
1499
+ if data_partition_id is not None:
1500
+ additional_headers["data-partition-id"] = data_partition_id
1501
+
1502
+ subprotocols = ["etp12.energistics.org"]
1503
+ async for ws in websockets.connect(
1504
+ uri=uri,
1505
+ subprotocols=subprotocols,
1506
+ max_size=max_message_size,
1507
+ additional_headers=additional_headers,
1508
+ ):
1509
+ try:
1510
+ async with ETPClient(
1511
+ ws=ws,
1512
+ etp_timeout=etp_timeout,
1513
+ max_message_size=max_message_size,
1514
+ ) as etp_client:
1515
+ yield etp_client
1516
+ except websockets.ConnectionClosed as e:
1517
+ logger.info(
1518
+ f"Websockets connection closed with message '{e}'. Starting new "
1519
+ "connection"
1520
+ )
1521
+ continue
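
Taken together, the new `etp_connect` helper and the `upload_array`/`download_array` methods support a round trip like the sketch below; the endpoint URL, bearer token, EPC object URI and path-in-resource are placeholders, and error handling is omitted:

```python
import asyncio
import numpy as np
from pyetp import etp_connect

async def main() -> None:
    async with etp_connect(
        uri="wss://example.com/etp",     # placeholder endpoint
        authorization="Bearer <token>",   # placeholder credential
        etp_timeout=60.0,
        max_message_size=2**23,
    ) as client:
        data = np.arange(12, dtype=np.float32).reshape(3, 4)
        epc_uri = "<epc-object-uri>"      # placeholder EPC object URI
        path = "<path-in-resource>"       # placeholder HDF5 path
        # upload_array falls back to sub-array blocks when the array exceeds
        # the advertised maximum websocket message size.
        await client.upload_array(epc_uri, path, data)
        roundtrip = await client.download_array(epc_uri, path)
        assert np.array_equal(data, roundtrip)

asyncio.run(main())
```
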
pyetp/utils_arrays.py CHANGED
@@ -1,3 +1,4 @@
1
+ import sys
1
2
  import typing as T
2
3
 
3
4
  import numpy as np
@@ -15,9 +16,6 @@ from etptypes.energistics.etp.v12.datatypes.array_of_long import ArrayOfLong
15
16
  from etptypes.energistics.etp.v12.datatypes.data_array_types.data_array import (
16
17
  DataArray,
17
18
  )
18
- from etptypes.energistics.etp.v12.datatypes.data_array_types.data_array_metadata import (
19
- DataArrayMetadata,
20
- )
21
19
 
22
20
  SUPPORTED_ARRAY_TYPES: T.TypeAlias = (
23
21
  ArrayOfFloat | ArrayOfBoolean | ArrayOfInt | ArrayOfLong | ArrayOfDouble
@@ -51,6 +49,14 @@ _ANY_LOGICAL_ARRAY_TYPE_MAP: dict[npt.DTypeLike, AnyLogicalArrayType] = {
51
49
  np.dtype(">f8"): AnyLogicalArrayType.ARRAY_OF_DOUBLE64_BE,
52
50
  }
53
51
 
52
+ valid_logical_array_dtypes = list(_ANY_LOGICAL_ARRAY_TYPE_MAP)
53
+ if (sys.version_info.major, sys.version_info.minor) == (3, 10):
54
+ LogicalArrayDTypes: T.TypeAlias = T.Union[
55
+ tuple(v.type for v in valid_logical_array_dtypes)
56
+ ]
57
+ else:
58
+ LogicalArrayDTypes: T.TypeAlias = T.Union[tuple(valid_logical_array_dtypes)]
59
+
54
60
  _INV_ANY_LOGICAL_ARRAY_TYPE_MAP: dict[AnyLogicalArrayType, npt.DTypeLike] = {
55
61
  v: k for k, v in _ANY_LOGICAL_ARRAY_TYPE_MAP.items()
56
62
  }
@@ -91,7 +97,7 @@ _ANY_ARRAY_TYPE_MAP: dict[npt.DTypeLike, AnyArrayType] = {
91
97
  np.dtype("<f4"): AnyArrayType.ARRAY_OF_FLOAT,
92
98
  np.dtype("<f8"): AnyArrayType.ARRAY_OF_DOUBLE,
93
99
  }
94
- valid_dtypes = list(_ANY_ARRAY_TYPE_MAP)
100
+ valid_any_array_dtypes = list(_ANY_ARRAY_TYPE_MAP)
95
101
 
96
102
  _INV_ANY_ARRAY_TYPE_MAP: dict[AnyArrayType, npt.DTypeLike] = {
97
103
  AnyArrayType.ARRAY_OF_BOOLEAN: np.dtype(np.bool_),
@@ -121,7 +127,7 @@ _INV_ANY_ARRAY_MAP: dict[SUPPORTED_ARRAY_TYPES, AnyArrayType] = {
121
127
 
122
128
 
123
129
  def check_if_array_is_valid_dtype(array: npt.NDArray[T.Any]) -> bool:
124
- return array.dtype in valid_dtypes
130
+ return array.dtype in valid_any_array_dtypes
125
131
 
126
132
 
127
133
  def get_valid_dtype_cast(array: npt.NDArray[T.Any]) -> npt.DTypeLike:
@@ -210,9 +216,11 @@ def get_etp_data_array_from_numpy(data: npt.NDArray) -> DataArray:
210
216
  )
211
217
 
212
218
 
213
- def get_transport_array_size(metadata: DataArrayMetadata) -> int:
214
- dtype = _INV_ANY_ARRAY_TYPE_MAP[metadata.transport_array_type]
215
- return int(np.prod(metadata.dimensions) * dtype.itemsize)
219
+ def get_transport_array_size(
220
+ transport_array_type: AnyArrayType, dimensions: list[int] | tuple[int]
221
+ ) -> int:
222
+ dtype = _INV_ANY_ARRAY_TYPE_MAP[transport_array_type]
223
+ return int(np.prod(dimensions) * dtype.itemsize)
216
224
 
217
225
 
218
226
  def get_dtype_from_any_array_class(cls: AnyArray) -> npt.DTypeLike:
@@ -235,11 +243,18 @@ def get_dtype_from_any_array_class(cls: AnyArray) -> npt.DTypeLike:
235
243
  raise TypeError(f"Class {cls} is not a valid array class")
236
244
 
237
245
 
238
- def get_dtype_from_any_array_type(_type: T.Union[AnyArrayType | str]) -> npt.DTypeLike:
246
+ def get_dtype_from_any_array_type(_type: AnyArrayType | str) -> npt.DTypeLike:
239
247
  enum_name = AnyArrayType(_type)
240
248
  return _INV_ANY_ARRAY_TYPE_MAP[enum_name]
241
249
 
242
250
 
251
+ def get_dtype_from_any_logical_array_type(
252
+ _type: AnyLogicalArrayType | str,
253
+ ) -> npt.DTypeLike:
254
+ enum_name = AnyLogicalArrayType(_type)
255
+ return _INV_ANY_LOGICAL_ARRAY_TYPE_MAP[enum_name]
256
+
257
+
243
258
  def get_numpy_array_from_etp_data_array(
244
259
  data_array: DataArray,
245
260
  ) -> npt.NDArray[
@@ -258,3 +273,67 @@ def get_numpy_array_from_etp_data_array(
258
273
  return np.array(data_array.data.item.values, dtype=dtype).reshape(
259
274
  data_array.dimensions
260
275
  )
276
+
277
+
278
+ def get_array_block_sizes(
279
+ shape: tuple[int], dtype: npt.DTypeLike, max_array_size: int
280
+ ) -> tuple[list[list[int]], list[list[int]]]:
281
+ # Total size of array in bytes.
282
+ array_size = int(np.prod(shape) * dtype.itemsize)
283
+ # Calculate the minimum number of blocks needed (if the array was flat).
284
+ num_blocks = int(np.ceil(array_size / max_array_size))
285
+
286
+ # Check if we can split on the first axis.
287
+ if num_blocks > shape[0]:
288
+ assert len(shape) > 1
289
+ # Recursively get block sizes on higher axes.
290
+ starts, counts = get_array_block_sizes(shape[1:], dtype, max_array_size)
291
+ # Repeat the starts and counts from the higher axes for each index along axis 0.
292
+ starts = [[i] + s for i in range(shape[0]) for s in starts]
293
+ counts = [[1] + c for i in range(shape[0]) for c in counts]
294
+
295
+ return starts, counts
296
+
297
+ # Count the number of axis elements (e.g., rows for a 2d-array) in each
298
+ # block, and the number of leftover elements that remain.
299
+ num_elements_in_block, num_remainder = divmod(shape[0], num_blocks)
300
+
301
+ # Get the number of extra blocks needed to fill in the remaining elements.
302
+ num_extra_blocks = num_remainder // num_elements_in_block + int(
303
+ num_remainder % num_elements_in_block > 0
304
+ )
305
+ # Count the number of elements in the last block.
306
+ num_elements_in_last_block = num_remainder % num_elements_in_block
307
+ # Increase the number of blocks to fit the remaining elements.
308
+ num_blocks += num_extra_blocks
309
+
310
+ # Verify that we still have at least as many axis elements as blocks.
311
+ assert num_blocks <= shape[0]
312
+
313
+ # Set up the number of axis elements in each block.
314
+ axis_counts = np.ones(num_blocks, dtype=int) * num_elements_in_block
315
+ if num_elements_in_last_block > 0:
316
+ assert num_elements_in_last_block < num_elements_in_block
317
+ # Alter the last block with the remaining number of elements.
318
+ axis_counts[-1] = num_elements_in_last_block
319
+
320
+ # Create an array with starting indices for each block and a corresponding
321
+ # array with the number of elements in each block.
322
+ starts = np.zeros((num_blocks, len(shape)), dtype=int)
323
+ counts = np.zeros_like(starts)
324
+
325
+ # Cumulatively sum the element counts to get the starting index of each
326
+ # block (starting at 0).
327
+ starts[1:, 0] = np.cumsum(axis_counts[:-1])
328
+
329
+ # The axis_counts already lists the number of elements in the first axis,
330
+ # so we only add the length of each remaining axis as the counts for the
331
+ # last axes.
332
+ counts[:, 0] = axis_counts
333
+ counts[:, 1:] = shape[1:]
334
+
335
+ # Check that no block exceeds the maximum size.
336
+ count_size = np.prod(counts, axis=1) * dtype.itemsize
337
+ assert np.all(count_size - max_array_size <= 0)
338
+
339
+ return starts.tolist(), counts.tolist()
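
The new `get_array_block_sizes` helper returns, per block, the starting index and element count along every axis; the client code in `download_array`/`upload_array` turns those into slices. A small sketch of that mapping, with an arbitrary shape and an assumed budget of roughly 64 KiB of array payload per message:

```python
import numpy as np
from pyetp.utils_arrays import get_array_block_sizes

data = np.arange(1000 * 50, dtype=np.float32).reshape(1000, 50)
# Pretend each message may carry roughly 64 KiB of array payload.
starts, counts = get_array_block_sizes(data.shape, data.dtype, 2**16)

reassembled = np.zeros_like(data)
for start, count in zip(starts, counts):
    # Each block covers the half-open range [start, start + count) on every axis.
    slices = tuple(slice(s, s + c) for s, c in zip(start, count))
    reassembled[slices] = data[slices]

assert np.array_equal(data, reassembled)
```
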
pyetp-0.0.43.dist-info/METADATA → pyetp-0.0.44.dist-info/METADATA CHANGED
@@ -1,20 +1,21 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyetp
3
- Version: 0.0.43
3
+ Version: 0.0.44
4
4
  Summary: Interface with OSDU RDDMS using ETP protocol
5
5
  Author-email: Adam Cheng <52572642+adamchengtkc@users.noreply.github.com>
6
6
  License-Expression: Apache-2.0
7
7
  Project-URL: homepage, https://github.com/equinor/pyetp
8
8
  Classifier: Development Status :: 3 - Alpha
9
+ Requires-Python: >=3.10
9
10
  Description-Content-Type: text/markdown
10
11
  License-File: LICENSE.md
11
- Requires-Dist: numpy>=1.26.3
12
- Requires-Dist: websockets>=12.0
12
+ Requires-Dist: numpy>=2.0
13
+ Requires-Dist: websockets>=15.0
13
14
  Requires-Dist: lxml>=4.9.4
14
15
  Requires-Dist: pydantic>=1.10
15
16
  Requires-Dist: async-timeout>=5.0
16
17
  Requires-Dist: xtgeo>=4.0.0
17
- Requires-Dist: xsdata>=24.3.1
18
+ Requires-Dist: xsdata>=25.4
18
19
  Requires-Dist: etpproto>=1.0.7
19
20
  Dynamic: license-file
20
21
 
@@ -29,6 +30,7 @@ This package is published to PyPI, and can be installed via:
29
30
  ```bash
30
31
  pip install pyetp
31
32
  ```
33
+ The library is tested against Python versions 3.10, 3.11, 3.12 and 3.13.
32
34
 
33
35
  ## Local development
34
36
  Locally we suggest setting up a virtual environment, and installing the latest
pyetp-0.0.43.dist-info/RECORD → pyetp-0.0.44.dist-info/RECORD CHANGED
@@ -1,13 +1,13 @@
1
- pyetp/__init__.py,sha256=Vu3_qz0AazlD4Q6ZLGdQVqZ7lhzqNLivpQRS3owD8SY,251
2
- pyetp/_version.py,sha256=qJphZkKjg5qeGMzBrtqKwyJZQtcx3oaEh-0t6ejMuVo,706
3
- pyetp/client.py,sha256=kZsPi2OPvbQKX6IVoAsrdZ9_ZDbZYEhHuHOhfbf3XxQ,37427
1
+ pyetp/__init__.py,sha256=_mJRvg6XJgzHeRWd_AsBNonvWUBD3jroZ9MLehPRHP4,337
2
+ pyetp/_version.py,sha256=Zrt00MLeXbWP2ZsKSWktjZjWuHHC_-KLVmAWLgE0o-g,706
3
+ pyetp/client.py,sha256=BF-_sm9j_SbJsQZ7tK8kp5Ekl7FsNtMHAZFJAwHuTMQ,54475
4
4
  pyetp/config.py,sha256=uGEx6n-YF7Rtgwckf0ovRKNOKgCUsiQx4IA0Tyiqafk,870
5
5
  pyetp/resqml_objects.py,sha256=j00e8scSh-yYv4Lpp9WjzLiaKqJWd5Cs7ROcJcxxFVw,50721
6
6
  pyetp/types.py,sha256=zOfUzEQcgBvveWJyM6dD7U3xJ4SCkWElewNL0Ml-PPY,6761
7
7
  pyetp/uri.py,sha256=Y05vuTO-XOurgDavBeoPGOidALoCKjCBIb7YHmfbAco,3115
8
- pyetp/utils_arrays.py,sha256=bLrb8H8TMjbpNLo0Zb9xYRUnM2W4BOf7WxOxWK4fcTc,10218
8
+ pyetp/utils_arrays.py,sha256=BJMiwx0mnSUjAOARCfCvWr_rz6N38e7kSWi24I1bCLA,13447
9
9
  pyetp/utils_xml.py,sha256=i11Zv2PuxW-ejEXxfGPYHvuJv7EHPYXwvgoe1U8fOUM,6826
10
- pyetp-0.0.43.dist-info/licenses/LICENSE.md,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
+ pyetp-0.0.44.dist-info/licenses/LICENSE.md,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
11
  resqml_objects/__init__.py,sha256=ilimTrQScFryxHrfZwZURVpW0li19bVnayUcHR7S_Fs,183
12
12
  resqml_objects/epc_readers.py,sha256=InYMlwjiZZRG9poQlWKFEOUNJmowGahXoNu_GniOaBw,3458
13
13
  resqml_objects/parsers.py,sha256=UZdBi3QmBq4ejwKI9Dse_lMqL5Bx2s3QNTs8TKS4fO0,427
@@ -15,7 +15,7 @@ resqml_objects/serializers.py,sha256=JqRO6D6ExT5DrVyiNwBgNW28108f6spvxiVqJU0D9mc
15
15
  resqml_objects/v201/__init__.py,sha256=yL3jWgkyGAzr-wt_WJDV_eARE75NoA6SPEKBrKM4Crk,51630
16
16
  resqml_objects/v201/generated.py,sha256=Se0eePS6w25sfmnp2UBSkgzDGJ9c9Y2QqJgDRUTt_-Q,769527
17
17
  resqml_objects/v201/utils.py,sha256=WiywauiJRBWhdjUvbKhpltRjoBX3qWd7qQ0_FAmIzUc,1442
18
- pyetp-0.0.43.dist-info/METADATA,sha256=nFszNnd-08BucrOKHaFZaB-kMiFz_vqxk6QTWdl04bU,2587
19
- pyetp-0.0.43.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
- pyetp-0.0.43.dist-info/top_level.txt,sha256=NrdXbidkT5QR4NjH6nv2Frixknqse3jZq7bnqNdVb5k,21
21
- pyetp-0.0.43.dist-info/RECORD,,
18
+ pyetp-0.0.44.dist-info/METADATA,sha256=51q_BHQ-1OSNjt021B3R8zveiDViTHWapaaSrgYVy6U,2679
19
+ pyetp-0.0.44.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
+ pyetp-0.0.44.dist-info/top_level.txt,sha256=NrdXbidkT5QR4NjH6nv2Frixknqse3jZq7bnqNdVb5k,21
21
+ pyetp-0.0.44.dist-info/RECORD,,