frogml-core 0.0.72__py3-none-any.whl → 0.0.73__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,231 @@
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+
5
+ from frogml_proto.qwak.service_discovery import service_discovery_location_service_pb2 as qwak_dot_service__discovery_dot_service__discovery__location__service__pb2
6
+
7
+
8
+ class LocationDiscoveryServiceStub(object):
9
+ """Missing associated documentation comment in .proto file."""
10
+
11
+ def __init__(self, channel):
12
+ """Constructor.
13
+
14
+ Args:
15
+ channel: A grpc.Channel.
16
+ """
17
+ self.GetOfflineServingUrl = channel.unary_unary(
18
+ '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetOfflineServingUrl',
19
+ request_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
20
+ response_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
21
+ )
22
+ self.GetDistributionManagerUrl = channel.unary_unary(
23
+ '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetDistributionManagerUrl',
24
+ request_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
25
+ response_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
26
+ )
27
+ self.GetAnalyticsEngineUrl = channel.unary_unary(
28
+ '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetAnalyticsEngineUrl',
29
+ request_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
30
+ response_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
31
+ )
32
+ self.GetMetricsGatewayUrl = channel.unary_unary(
33
+ '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetMetricsGatewayUrl',
34
+ request_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
35
+ response_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
36
+ )
37
+ self.GetFeaturesOperatorUrl = channel.unary_unary(
38
+ '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetFeaturesOperatorUrl',
39
+ request_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
40
+ response_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
41
+ )
42
+ self.GetHostingGatewayUrl = channel.unary_unary(
43
+ '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetHostingGatewayUrl',
44
+ request_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
45
+ response_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
46
+ )
47
+
48
+
49
+ class LocationDiscoveryServiceServicer(object):
50
+ """Missing associated documentation comment in .proto file."""
51
+
52
+ def GetOfflineServingUrl(self, request, context):
53
+ """Missing associated documentation comment in .proto file."""
54
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
55
+ context.set_details('Method not implemented!')
56
+ raise NotImplementedError('Method not implemented!')
57
+
58
+ def GetDistributionManagerUrl(self, request, context):
59
+ """Missing associated documentation comment in .proto file."""
60
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
61
+ context.set_details('Method not implemented!')
62
+ raise NotImplementedError('Method not implemented!')
63
+
64
+ def GetAnalyticsEngineUrl(self, request, context):
65
+ """Missing associated documentation comment in .proto file."""
66
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
67
+ context.set_details('Method not implemented!')
68
+ raise NotImplementedError('Method not implemented!')
69
+
70
+ def GetMetricsGatewayUrl(self, request, context):
71
+ """Missing associated documentation comment in .proto file."""
72
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
73
+ context.set_details('Method not implemented!')
74
+ raise NotImplementedError('Method not implemented!')
75
+
76
+ def GetFeaturesOperatorUrl(self, request, context):
77
+ """Missing associated documentation comment in .proto file."""
78
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
79
+ context.set_details('Method not implemented!')
80
+ raise NotImplementedError('Method not implemented!')
81
+
82
+ def GetHostingGatewayUrl(self, request, context):
83
+ """Missing associated documentation comment in .proto file."""
84
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
85
+ context.set_details('Method not implemented!')
86
+ raise NotImplementedError('Method not implemented!')
87
+
88
+
89
+ def add_LocationDiscoveryServiceServicer_to_server(servicer, server):
90
+ rpc_method_handlers = {
91
+ 'GetOfflineServingUrl': grpc.unary_unary_rpc_method_handler(
92
+ servicer.GetOfflineServingUrl,
93
+ request_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.FromString,
94
+ response_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.SerializeToString,
95
+ ),
96
+ 'GetDistributionManagerUrl': grpc.unary_unary_rpc_method_handler(
97
+ servicer.GetDistributionManagerUrl,
98
+ request_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.FromString,
99
+ response_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.SerializeToString,
100
+ ),
101
+ 'GetAnalyticsEngineUrl': grpc.unary_unary_rpc_method_handler(
102
+ servicer.GetAnalyticsEngineUrl,
103
+ request_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.FromString,
104
+ response_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.SerializeToString,
105
+ ),
106
+ 'GetMetricsGatewayUrl': grpc.unary_unary_rpc_method_handler(
107
+ servicer.GetMetricsGatewayUrl,
108
+ request_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.FromString,
109
+ response_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.SerializeToString,
110
+ ),
111
+ 'GetFeaturesOperatorUrl': grpc.unary_unary_rpc_method_handler(
112
+ servicer.GetFeaturesOperatorUrl,
113
+ request_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.FromString,
114
+ response_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.SerializeToString,
115
+ ),
116
+ 'GetHostingGatewayUrl': grpc.unary_unary_rpc_method_handler(
117
+ servicer.GetHostingGatewayUrl,
118
+ request_deserializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.FromString,
119
+ response_serializer=qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.SerializeToString,
120
+ ),
121
+ }
122
+ generic_handler = grpc.method_handlers_generic_handler(
123
+ 'qwak.service_discovery.location.v0.LocationDiscoveryService', rpc_method_handlers)
124
+ server.add_generic_rpc_handlers((generic_handler,))
125
+
126
+
127
+ # This class is part of an EXPERIMENTAL API.
128
+ class LocationDiscoveryService(object):
129
+ """Missing associated documentation comment in .proto file."""
130
+
131
+ @staticmethod
132
+ def GetOfflineServingUrl(request,
133
+ target,
134
+ options=(),
135
+ channel_credentials=None,
136
+ call_credentials=None,
137
+ insecure=False,
138
+ compression=None,
139
+ wait_for_ready=None,
140
+ timeout=None,
141
+ metadata=None):
142
+ return grpc.experimental.unary_unary(request, target, '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetOfflineServingUrl',
143
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
144
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
145
+ options, channel_credentials,
146
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
147
+
148
+ @staticmethod
149
+ def GetDistributionManagerUrl(request,
150
+ target,
151
+ options=(),
152
+ channel_credentials=None,
153
+ call_credentials=None,
154
+ insecure=False,
155
+ compression=None,
156
+ wait_for_ready=None,
157
+ timeout=None,
158
+ metadata=None):
159
+ return grpc.experimental.unary_unary(request, target, '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetDistributionManagerUrl',
160
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
161
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
162
+ options, channel_credentials,
163
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
164
+
165
+ @staticmethod
166
+ def GetAnalyticsEngineUrl(request,
167
+ target,
168
+ options=(),
169
+ channel_credentials=None,
170
+ call_credentials=None,
171
+ insecure=False,
172
+ compression=None,
173
+ wait_for_ready=None,
174
+ timeout=None,
175
+ metadata=None):
176
+ return grpc.experimental.unary_unary(request, target, '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetAnalyticsEngineUrl',
177
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
178
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
179
+ options, channel_credentials,
180
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
181
+
182
+ @staticmethod
183
+ def GetMetricsGatewayUrl(request,
184
+ target,
185
+ options=(),
186
+ channel_credentials=None,
187
+ call_credentials=None,
188
+ insecure=False,
189
+ compression=None,
190
+ wait_for_ready=None,
191
+ timeout=None,
192
+ metadata=None):
193
+ return grpc.experimental.unary_unary(request, target, '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetMetricsGatewayUrl',
194
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
195
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
196
+ options, channel_credentials,
197
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
198
+
199
+ @staticmethod
200
+ def GetFeaturesOperatorUrl(request,
201
+ target,
202
+ options=(),
203
+ channel_credentials=None,
204
+ call_credentials=None,
205
+ insecure=False,
206
+ compression=None,
207
+ wait_for_ready=None,
208
+ timeout=None,
209
+ metadata=None):
210
+ return grpc.experimental.unary_unary(request, target, '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetFeaturesOperatorUrl',
211
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
212
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
213
+ options, channel_credentials,
214
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
215
+
216
+ @staticmethod
217
+ def GetHostingGatewayUrl(request,
218
+ target,
219
+ options=(),
220
+ channel_credentials=None,
221
+ call_credentials=None,
222
+ insecure=False,
223
+ compression=None,
224
+ wait_for_ready=None,
225
+ timeout=None,
226
+ metadata=None):
227
+ return grpc.experimental.unary_unary(request, target, '/qwak.service_discovery.location.v0.LocationDiscoveryService/GetHostingGatewayUrl',
228
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequest.SerializeToString,
229
+ qwak_dot_service__discovery_dot_service__discovery__location__service__pb2.GetServingUrlRequestResponse.FromString,
230
+ options, channel_credentials,
231
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@@ -64,6 +64,9 @@ from frogml_services_mock.mocks.job_registry_service_api import (
64
64
  from frogml_services_mock.mocks.kube_captain_service_api import (
65
65
  KubeCaptainServiceApiMock,
66
66
  )
67
+ from frogml_services_mock.mocks.location_discovery_service_api import (
68
+ LocationDiscoveryServiceApiMock,
69
+ )
67
70
  from frogml_services_mock.mocks.logging_service import LoggingServiceApiMock
68
71
  from frogml_services_mock.mocks.model_management_service import (
69
72
  ModelsManagementServiceMock,
@@ -131,3 +134,4 @@ class FrogmlMocks:
131
134
  prompt_manager_service: PromptManagerServiceMock
132
135
  model_version_manager_service: ModelVersionManagerServiceMock
133
136
  repository_service: RepositoryServiceMock
137
+ location_discovery_service: LocationDiscoveryServiceApiMock
@@ -0,0 +1,104 @@
1
+ from typing import Dict, Optional
2
+
3
+ import grpc
4
+ from frogml_proto.qwak.service_discovery.service_discovery_location_pb2 import (
5
+ ServiceLocationDescriptor,
6
+ )
7
+ from frogml_proto.qwak.service_discovery.service_discovery_location_service_pb2 import (
8
+ GetServingUrlRequestResponse,
9
+ )
10
+ from frogml_proto.qwak.service_discovery.service_discovery_location_service_pb2_grpc import (
11
+ LocationDiscoveryServiceServicer,
12
+ )
13
+
14
+
15
+ class LocationDiscoveryServiceApiMock(LocationDiscoveryServiceServicer):
16
+ """
17
+ Mock implementation of the LocationDiscoveryService for testing SDK behavior.
18
+ Allows setting mock responses and optional error codes for each endpoint.
19
+ """
20
+
21
+ def __init__(self):
22
+ super().__init__()
23
+ self._responses: Dict[str, Optional[ServiceLocationDescriptor]] = {}
24
+ self._error_codes: Dict[str, grpc.StatusCode] = {}
25
+
26
+ def _set_mock(
27
+ self,
28
+ key: str,
29
+ response: Optional[ServiceLocationDescriptor],
30
+ error_code: grpc.StatusCode = grpc.StatusCode.NOT_FOUND,
31
+ ):
32
+ self._responses[key] = response
33
+ self._error_codes[key] = error_code
34
+
35
+ def _handle(
36
+ self, key: str, context: grpc.ServicerContext
37
+ ) -> GetServingUrlRequestResponse:
38
+ response = self._responses.get(key)
39
+ if response:
40
+ return GetServingUrlRequestResponse(location=response)
41
+ context.set_code(self._error_codes.get(key, grpc.StatusCode.NOT_FOUND))
42
+ context.set_details(f"No mock response set for {key}")
43
+ return GetServingUrlRequestResponse()
44
+
45
+ # Setters
46
+ def set_get_offline_serving_url_response(
47
+ self,
48
+ response: Optional[ServiceLocationDescriptor],
49
+ error_code: grpc.StatusCode = grpc.StatusCode.NOT_FOUND,
50
+ ):
51
+ self._set_mock("offline", response, error_code)
52
+
53
+ def set_get_distribution_manager_url_response(
54
+ self,
55
+ response: Optional[ServiceLocationDescriptor],
56
+ error_code: grpc.StatusCode = grpc.StatusCode.NOT_FOUND,
57
+ ):
58
+ self._set_mock("distribution", response, error_code)
59
+
60
+ def set_get_analytics_engine_url_response(
61
+ self,
62
+ response: Optional[ServiceLocationDescriptor],
63
+ error_code: grpc.StatusCode = grpc.StatusCode.NOT_FOUND,
64
+ ):
65
+ self._set_mock("analytics", response, error_code)
66
+
67
+ def set_get_metrics_gateway_url_response(
68
+ self,
69
+ response: Optional[ServiceLocationDescriptor],
70
+ error_code: grpc.StatusCode = grpc.StatusCode.NOT_FOUND,
71
+ ):
72
+ self._set_mock("metrics", response, error_code)
73
+
74
+ def set_get_features_operator_url_response(
75
+ self,
76
+ response: Optional[ServiceLocationDescriptor],
77
+ error_code: grpc.StatusCode = grpc.StatusCode.NOT_FOUND,
78
+ ):
79
+ self._set_mock("features", response, error_code)
80
+
81
+ def set_get_hosting_gateway_url_response(
82
+ self,
83
+ response: Optional[ServiceLocationDescriptor],
84
+ error_code: grpc.StatusCode = grpc.StatusCode.NOT_FOUND,
85
+ ):
86
+ self._set_mock("hosting", response, error_code)
87
+
88
+ def GetOfflineServingUrl(self, request, context):
89
+ return self._handle("offline", context)
90
+
91
+ def GetDistributionManagerUrl(self, request, context):
92
+ return self._handle("distribution", context)
93
+
94
+ def GetAnalyticsEngineUrl(self, request, context):
95
+ return self._handle("analytics", context)
96
+
97
+ def GetMetricsGatewayUrl(self, request, context):
98
+ return self._handle("metrics", context)
99
+
100
+ def GetFeaturesOperatorUrl(self, request, context):
101
+ return self._handle("features", context)
102
+
103
+ def GetHostingGatewayUrl(self, request, context):
104
+ return self._handle("hosting", context)
@@ -3,7 +3,7 @@ from typing import Any, Generator, List, Tuple
3
3
 
4
4
  import grpc
5
5
  import pytest
6
-
6
+ from frogml_core.inner.di_configuration import FrogmlContainer
7
7
  from frogml_proto.jfml.model_version.v1.model_version_manager_service_pb2_grpc import (
8
8
  add_ModelVersionManagerServiceServicer_to_server,
9
9
  )
@@ -113,6 +113,9 @@ from frogml_proto.qwak.secret_service.secret_service_pb2_grpc import (
113
113
  from frogml_proto.qwak.self_service.user.v1.user_service_pb2_grpc import (
114
114
  add_UserServiceServicer_to_server,
115
115
  )
116
+ from frogml_proto.qwak.service_discovery.service_discovery_location_service_pb2_grpc import (
117
+ add_LocationDiscoveryServiceServicer_to_server,
118
+ )
116
119
  from frogml_proto.qwak.vectors.v1.collection.collection_service_pb2_grpc import (
117
120
  add_VectorCollectionServiceServicer_to_server,
118
121
  )
@@ -122,7 +125,6 @@ from frogml_proto.qwak.vectors.v1.vector_service_pb2_grpc import (
122
125
  from frogml_proto.qwak.workspace.workspace_service_pb2_grpc import (
123
126
  add_WorkspaceManagementServiceServicer_to_server,
124
127
  )
125
- from frogml_core.inner.di_configuration import FrogmlContainer
126
128
  from frogml_services_mock.mocks.alert_manager_service_api import (
127
129
  AlertManagerServiceApiMock,
128
130
  )
@@ -187,6 +189,9 @@ from frogml_services_mock.mocks.job_registry_service_api import (
187
189
  from frogml_services_mock.mocks.kube_captain_service_api import (
188
190
  KubeCaptainServiceApiMock,
189
191
  )
192
+ from frogml_services_mock.mocks.location_discovery_service_api import (
193
+ LocationDiscoveryServiceApiMock,
194
+ )
190
195
  from frogml_services_mock.mocks.logging_service import LoggingServiceApiMock
191
196
  from frogml_services_mock.mocks.model_management_service import (
192
197
  ModelsManagementServiceMock,
@@ -249,16 +254,17 @@ def frogml_container():
249
254
  feature_store,
250
255
  file_versioning,
251
256
  instance_template,
257
+ jfrog_gateway,
252
258
  kube_deployment_captain,
259
+ location_discovery,
253
260
  logging_client,
254
261
  model_management,
262
+ model_version_manager,
255
263
  project,
256
264
  secret_service,
257
265
  user_application_instance,
258
266
  vector_store,
259
267
  workspace_manager,
260
- model_version_manager,
261
- jfrog_gateway,
262
268
  )
263
269
  from frogml_core.clients.administration import (
264
270
  authentication,
@@ -303,6 +309,7 @@ def frogml_container():
303
309
  prompt_manager_client,
304
310
  model_version_manager,
305
311
  jfrog_gateway,
312
+ location_discovery,
306
313
  ]
307
314
  )
308
315
 
@@ -527,6 +534,11 @@ def attach_servicers(free_port, server):
527
534
  RepositoryServiceMock,
528
535
  add_RepositoryServiceServicer_to_server,
529
536
  ),
537
+ (
538
+ "location_discovery_service",
539
+ LocationDiscoveryServiceApiMock,
540
+ add_LocationDiscoveryServiceServicer_to_server,
541
+ ),
530
542
  ("port", free_port, None),
531
543
  ],
532
544
  )
@@ -1,32 +0,0 @@
1
- import uuid
2
- from abc import ABC, abstractmethod
3
-
4
- import pandas as pd
5
-
6
-
7
- class BaseQueryEngine(ABC):
8
- @abstractmethod
9
- def upload_table(self, df: pd.DataFrame):
10
- pass
11
-
12
- @abstractmethod
13
- def cleanup(self):
14
- pass
15
-
16
- @abstractmethod
17
- def run_query(self, query: str):
18
- pass
19
-
20
- @abstractmethod
21
- def read_pandas_from_query(self, query: str, parse_dates=None):
22
- pass
23
-
24
- @staticmethod
25
- @abstractmethod
26
- def get_quotes():
27
- pass
28
-
29
- class JoinTableSpec:
30
- def __init__(self, join_tables_db_name: str, quotes: str):
31
- self.table_name = str(uuid.uuid4())
32
- self.join_table_full_path = f"{quotes}{join_tables_db_name}{quotes}.{quotes}{self.table_name}{quotes}"
File without changes
@@ -1,154 +0,0 @@
1
- import time
2
- import uuid
3
-
4
- import pandas as pd
5
- from google.protobuf.duration_pb2 import Duration
6
-
7
- from frogml_proto.qwak.ecosystem.v0.ecosystem_runtime_service_pb2 import (
8
- GetCloudCredentialsParameters,
9
- GetCloudCredentialsRequest,
10
- OfflineFeatureStoreClient,
11
- PermissionSet,
12
- )
13
- from frogml_core.clients.administration.eco_system.client import EcosystemClient
14
- from frogml_core.exceptions import FrogmlException
15
- from frogml_core.feature_store.offline._query_engine import BaseQueryEngine
16
-
17
- RECONNECT_THRESHOLD_SEC = 300
18
-
19
-
20
- class AthenaQueryEngine(BaseQueryEngine):
21
- def __init__(self):
22
- eco_client = EcosystemClient()
23
- self.bucket, environment_id = self._get_env_details(eco_client)
24
-
25
- self.staging_folder_prefix = (
26
- f"{environment_id}/tmp/offline_fs/{str(uuid.uuid4())}" # nosec B108
27
- )
28
- self.temp_join_table_base_folder = (
29
- f"s3://{self.bucket}/{self.staging_folder_prefix}"
30
- )
31
-
32
- self.conn, self.expiration_time = self._init_connection()
33
- self.cursor = self.conn.cursor()
34
-
35
- self.join_table_specs = []
36
-
37
- self.join_tables_db_name = f"qwak_temp_data_{environment_id.replace('-', '_')}"
38
- self.cursor.execute(f"CREATE DATABASE IF NOT EXISTS {self.join_tables_db_name}")
39
-
40
- @staticmethod
41
- def _get_env_details(eco_client):
42
- environment_configuration = eco_client.get_environment_configuration()
43
-
44
- return (
45
- environment_configuration.configuration.object_storage_bucket,
46
- environment_configuration.id,
47
- )
48
-
49
- def _init_connection(self):
50
- try:
51
- # obtain credentials through STS
52
- eco_client = EcosystemClient()
53
- cloud_credentials_response = eco_client.get_cloud_credentials(
54
- request=GetCloudCredentialsRequest(
55
- parameters=GetCloudCredentialsParameters(
56
- duration=Duration(seconds=60 * 60, nanos=0),
57
- permission_set=PermissionSet(
58
- offline_feature_store_client=OfflineFeatureStoreClient()
59
- ),
60
- )
61
- )
62
- )
63
-
64
- aws_credentials = (
65
- cloud_credentials_response.cloud_credentials.aws_temporary_credentials
66
- )
67
-
68
- try:
69
- from pyathena import connect
70
- from pyathena.pandas.cursor import PandasCursor
71
- except ImportError:
72
- raise FrogmlException(
73
- """
74
- Missing 'pyathena' dependency required for fetching data from the offline store.
75
- Please pip install pyathena
76
- """
77
- )
78
-
79
- conn = connect(
80
- s3_staging_dir=self.temp_join_table_base_folder,
81
- aws_access_key_id=aws_credentials.access_key_id,
82
- aws_secret_access_key=aws_credentials.secret_access_key,
83
- aws_session_token=aws_credentials.session_token,
84
- region_name=aws_credentials.region,
85
- cursor_class=PandasCursor,
86
- )
87
-
88
- return (
89
- conn,
90
- aws_credentials.expiration_time.seconds,
91
- )
92
-
93
- except FrogmlException as e:
94
- raise e
95
-
96
- except Exception as e:
97
- raise FrogmlException(
98
- f"Got an error trying to retrieve credentials to query the offline store "
99
- f"in the cloud, error is: {e}"
100
- )
101
-
102
- def upload_table(self, df: pd.DataFrame):
103
- join_table_spec = super().JoinTableSpec(
104
- self.join_tables_db_name, AthenaQueryEngine.get_quotes()
105
- )
106
- self.join_table_specs.append(join_table_spec)
107
-
108
- from pyathena.pandas.util import to_sql
109
-
110
- to_sql(
111
- df,
112
- join_table_spec.table_name,
113
- self.conn,
114
- f"{self.temp_join_table_base_folder}/{join_table_spec.table_name}/",
115
- schema=self.join_tables_db_name,
116
- index=False,
117
- if_exists="replace",
118
- )
119
-
120
- return join_table_spec.join_table_full_path
121
-
122
- def run_query(self, query: str):
123
- self._check_reconnection()
124
- return self.cursor.execute(query).fetchall()
125
-
126
- def read_pandas_from_query(self, query: str, parse_dates=None):
127
- self._check_reconnection()
128
- return pd.read_sql(
129
- query,
130
- self.conn,
131
- parse_dates=parse_dates,
132
- )
133
-
134
- def _check_reconnection(self):
135
- if self.expiration_time - time.time() < RECONNECT_THRESHOLD_SEC:
136
- self.conn, self.expiration_time = self._init_connection()
137
- self.cursor = self.conn.cursor()
138
-
139
- def cleanup(self):
140
- self._check_reconnection()
141
- for join_table_spec in self.join_table_specs:
142
- self.cursor.execute(
143
- f"""DROP TABLE {join_table_spec.join_table_full_path.replace('"', '`')}"""
144
- )
145
-
146
- self.join_table_specs = []
147
-
148
- s3 = self.conn.session.resource("s3")
149
- bucket = s3.Bucket(self.bucket)
150
- bucket.objects.filter(Prefix=self.staging_folder_prefix).delete()
151
-
152
- @staticmethod
153
- def get_quotes():
154
- return '"'