tairos-data-convert 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,406 @@
1
+ import json
2
+ import time
3
+ import base64
4
+ import hmac
5
+ import hashlib
6
+ from collections import defaultdict
7
+
8
+ from data_pipeline.py_api.raw_api import (
9
+ list_task_compiler,
10
+ list_collection_task,
11
+ list_data,
12
+ list_task_samples,
13
+ list_eval_tasks,
14
+ aggregate_data,
15
+ get_task_sample,
16
+ get_userinfo,
17
+ get_current_user_role,
18
+ send_data,
19
+ get_download_url,
20
+ )
21
+
22
# Endpoint used as the default `api_endpoint` argument by the helpers in this module.
DEFAULT_ENDPOINT = "https://roboticsx-data.woa.com"
23
+
24
+
25
class RemoteServiceError(Exception):
    """Raised when a call to the remote service reports failure."""
28
+
29
+
30
def sample_from_task_compiler_id(id, api_endpoint=DEFAULT_ENDPOINT, headers=None):
    """Request a task sample for the given task compiler.

    :param id: task compiler ID, sent to the service as ``task_compiler_id``
    :param api_endpoint: service endpoint handed to ``get_task_sample``
    :param headers: optional headers passed through to ``get_task_sample``
    :return: the ``task_sample`` entry of the response
    :raises RemoteServiceError: if ``get_task_sample`` reports failure
    """
    ok, response = get_task_sample({"task_compiler_id": id}, api_endpoint, headers)
    if ok:
        return response["task_sample"]
    raise RemoteServiceError(response)
38
+
39
+
40
def get_task_sample_by_id(id, api_endpoint=DEFAULT_ENDPOINT):
    """Fetch the task sample whose ``id`` field equals ``id``.

    :param id: task sample ID used as the query filter
    :param api_endpoint: service endpoint handed to ``list_task_samples``
    :return: the first entry of ``task_samples`` in the response
    :raises RemoteServiceError: if ``list_task_samples`` reports failure
    :raises ValueError: if the response has no ``task_samples`` key or the list is empty
    """
    request = {"query": json.dumps({"id": id})}

    ok, response = list_task_samples(request, api_endpoint)
    if not ok:
        raise RemoteServiceError(response)

    has_match = "task_samples" in response and len(response["task_samples"]) != 0
    if has_match:
        return response["task_samples"][0]
    raise ValueError(f"没有找到符合条件的task sample, id: {id}")
57
+
58
+
59
def get_all_data_from_dataset(dataset_id, api_endpoint=DEFAULT_ENDPOINT):
    """Collect every data record associated with a dataset.

    Pages through ``list_data`` 100 records at a time, starting at page 1,
    and stops at the first short page or the first response that lacks
    ``data.data_list``.

    :param dataset_id: dataset ID, matched against ``set_associations.set_id``
    :param api_endpoint: service endpoint handed to ``list_data``
    :return: list of all records gathered across pages
    :raises RemoteServiceError: if any ``list_data`` call reports failure
    """
    page_size = 100
    filter_json = json.dumps({"set_associations.set_id": dataset_id})

    collected = []
    page = 1
    while True:
        request = {
            "query": {"json": filter_json},
            "page_info": {"page": page, "page_size": page_size},
        }
        ok, response = list_data(request, api_endpoint)
        if not ok:
            raise RemoteServiceError(response)

        if "data" not in response or "data_list" not in response["data"]:
            break

        batch = response["data"]["data_list"]
        collected.extend(batch)
        # A page shorter than page_size is treated as the last one.
        if len(batch) < page_size:
            break
        page += 1

    return collected
97
+
98
+
99
def get_all_data_from_eval_task(task_id, api_endpoint=DEFAULT_ENDPOINT):
    """Collect every data record belonging to an evaluation task.

    Only records whose ``metadata.collect_info.task_status`` is 1 or 2 are
    requested. Pages through ``list_data`` 100 records at a time, starting
    at page 1, and stops at the first short page or the first response that
    lacks ``data.data_list``.

    :param task_id: evaluation task ID, matched against ``metadata.collection_task_id``
    :param api_endpoint: service endpoint handed to ``list_data``
    :return: list of all records gathered across pages
    :raises RemoteServiceError: if any ``list_data`` call reports failure
    """
    page_size = 100
    filter_json = json.dumps({
        "metadata.collection_task_id": task_id,
        "metadata.collect_info.task_status": {"$in": [1, 2]},
    })

    collected = []
    page = 1
    while True:
        request = {
            "query": {"json": filter_json},
            "page_info": {"page": page, "page_size": page_size},
        }
        ok, response = list_data(request, api_endpoint)
        if not ok:
            raise RemoteServiceError(response)

        if "data" not in response or "data_list" not in response["data"]:
            break

        batch = response["data"]["data_list"]
        collected.extend(batch)
        # A page shorter than page_size is treated as the last one.
        if len(batch) < page_size:
            break
        page += 1

    return collected
140
+
141
+
142
def get_task_compiler_by_id(id, api_endpoint=DEFAULT_ENDPOINT, headers=None):
    """Look up a single task compiler by its ID.

    :param id: task compiler ID used as the query filter
    :param api_endpoint: service endpoint handed to ``list_task_compiler``
    :param headers: optional request headers forwarded to
        ``list_task_compiler``, in the same way ``get_all_task_compilers``
        forwards them; defaults to ``None`` so existing callers are unaffected
    :return: the first entry of ``task_compilers`` in the response
    :raises RemoteServiceError: if ``list_task_compiler`` reports failure
    :raises ValueError: if the response has no ``task_compilers`` key or the list is empty
    """
    payload = {
        "query": json.dumps({"id": id}),
        "page_info": {
            "page": 1,
            "page_size": 100,
        },
    }

    success, result = list_task_compiler(payload, api_endpoint, headers)
    if not success:
        raise RemoteServiceError(result)

    if "task_compilers" not in result or len(result["task_compilers"]) == 0:
        raise ValueError(f"没有找到符合条件的task compiler, id: {id}")

    return result["task_compilers"][0]
163
+
164
+
165
def get_all_task_compilers(query, api_endpoint=DEFAULT_ENDPOINT, headers=None):
    """Collect every task compiler matching a query.

    Pages through ``list_task_compiler`` 100 entries at a time, starting at
    page 1, and stops at the first short page or the first response without
    a ``task_compilers`` key.

    :param query: JSON-serialisable query conditions
    :param api_endpoint: service endpoint handed to ``list_task_compiler``
    :param headers: optional headers passed through to ``list_task_compiler``
    :return: list of all task compilers gathered across pages
    :raises RemoteServiceError: if any ``list_task_compiler`` call reports failure
    """
    page_size = 100
    query_json = json.dumps(query)

    found = []
    page = 1
    while True:
        request = {
            "query": query_json,
            "page_info": {"page": page, "page_size": page_size},
        }
        ok, response = list_task_compiler(request, api_endpoint, headers)
        if not ok:
            raise RemoteServiceError(response)

        if "task_compilers" not in response:
            break

        batch = response["task_compilers"]
        found.extend(batch)
        # A page shorter than page_size is treated as the last one.
        if len(batch) < page_size:
            break
        page += 1

    return found
197
+
198
+
199
def get_all_collection_tasks(query, api_endpoint=DEFAULT_ENDPOINT, headers=None):
    """Collect every collection task matching a query.

    Pages through ``list_collection_task`` 100 entries at a time, starting
    at page 1, and stops at the first short page or the first response
    without a ``collection_tasks`` key.

    :param query: JSON-serialisable query conditions
    :param api_endpoint: service endpoint handed to ``list_collection_task``
    :param headers: optional request headers forwarded to
        ``list_collection_task``, in the same way ``get_collection_task_by_id``
        forwards them; defaults to ``None`` so existing callers are unaffected
    :return: list of all collection tasks gathered across pages
    :raises RemoteServiceError: if any ``list_collection_task`` call reports failure
    """
    payload = {
        "query": json.dumps(query)
    }

    page = 1
    page_size = 100
    collection_tasks = []

    while True:
        payload["page_info"] = {
            "page": page,
            "page_size": page_size,
        }
        success, result = list_collection_task(payload, api_endpoint, headers)
        if not success:
            raise RemoteServiceError(result)

        if "collection_tasks" not in result:
            break

        batch = result["collection_tasks"]
        collection_tasks.extend(batch)
        # A page shorter than page_size is treated as the last one.
        if len(batch) < page_size:
            break
        page += 1

    return collection_tasks
231
+
232
+
233
def get_collection_task_by_id(id, api_endpoint=DEFAULT_ENDPOINT, headers=None):
    """Look up a single collection task by its ID.

    :param id: collection task ID used as the query filter
    :param api_endpoint: service endpoint handed to ``list_collection_task``
    :param headers: optional headers passed through to ``list_collection_task``
    :return: the first entry of ``collection_tasks`` in the response
    :raises RemoteServiceError: if ``list_collection_task`` reports failure
    :raises ValueError: if the response has no ``collection_tasks`` key or the list is empty
    """
    request = {"query": json.dumps({"id": id})}
    ok, response = list_collection_task(request, api_endpoint, headers)
    if not ok:
        raise RemoteServiceError(response)

    has_match = "collection_tasks" in response and len(response["collection_tasks"]) != 0
    if has_match:
        return response["collection_tasks"][0]
    raise ValueError(f"没有找到符合条件的collection task, id: {id}, response: {response}")
249
+
250
+
251
def get_eval_task_by_id(id, api_endpoint=DEFAULT_ENDPOINT, headers=None):
    """Look up a single evaluation task by its ID.

    Unlike the other lookups in this module, the task list is read from the
    ``data`` entry of the response rather than from its top level.

    :param id: evaluation task ID used as the query filter
    :param api_endpoint: service endpoint handed to ``list_eval_tasks``
    :param headers: optional headers passed through to ``list_eval_tasks``
    :return: the first entry of ``data.eval_tasks`` in the response
    :raises RemoteServiceError: if ``list_eval_tasks`` reports failure
    :raises ValueError: if ``data`` has no ``eval_tasks`` key or the list is empty
    """
    request = {
        "query": json.dumps({"id": id}),
        "page_info": {"page": 1, "page_size": 1},
    }
    ok, response = list_eval_tasks(request, api_endpoint, headers)
    if not ok:
        raise RemoteServiceError(response)

    body = response["data"]
    has_match = "eval_tasks" in body and len(body["eval_tasks"]) != 0
    if has_match:
        return body["eval_tasks"][0]
    raise ValueError(f"没有找到符合条件的eval task, id: {id}")
271
+
272
+
273
def call_aggregate_data(match, group, sort, api_endpoint=DEFAULT_ENDPOINT, headers=None):
    """Run a three-stage ``$match`` / ``$group`` / ``$sort`` aggregation.

    Each stage is serialised to its own JSON string and the three strings
    are sent, in that order, as ``aggregate_pipeline``.

    :param match: body of the ``$match`` stage
    :param group: body of the ``$group`` stage
    :param sort: body of the ``$sort`` stage
    :param api_endpoint: service endpoint handed to ``aggregate_data``
    :param headers: optional headers passed through to ``aggregate_data``
    :return: the ``data`` entry of the response, or ``[]`` when it is absent
    :raises RemoteServiceError: if ``aggregate_data`` reports failure
    """
    stages = (("$match", match), ("$group", group), ("$sort", sort))
    request = {
        "aggregate_pipeline": [json.dumps({operator: spec}) for operator, spec in stages]
    }

    ok, response = aggregate_data(request, api_endpoint, headers)
    if ok:
        return response.get("data", [])
    raise RemoteServiceError(response)
286
+
287
+
288
def get_data_count_per_task_compiler_for_eval_task_id(id, api_endpoint=DEFAULT_ENDPOINT, headers=None):
    """Count an evaluation task's data records per task compiler.

    Only records whose ``metadata.collect_info.task_status`` is 1 or 2 are
    counted. Rows are read back from the ``group_fields`` and
    ``aggregate_fields`` entries of each aggregation result.

    :param id: evaluation task ID, matched against ``metadata.collection_task_id``
    :param api_endpoint: service endpoint handed to ``call_aggregate_data``
    :param headers: optional headers passed through to ``call_aggregate_data``
    :return: ``defaultdict(int)`` mapping task compiler ID to record count
    :raises RemoteServiceError: if the aggregation call reports failure
    """
    status_filter = {"$in": [1, 2]}
    match = {
        "metadata.collection_task_id": id,
        "metadata.collect_info.task_status": status_filter,
    }
    group = {
        "_id": {"task_compiler_id": "$metadata.task_compiler_id"},
        "count": {"$sum": 1},
    }
    sort = {"_id": 1}

    rows = call_aggregate_data(match, group, sort, api_endpoint, headers)

    counts = defaultdict(int)
    # Rows missing either field fall back to "" / 0; later rows overwrite earlier ones.
    counts.update(
        (
            row.get("group_fields", {}).get("task_compiler_id", ""),
            row.get("aggregate_fields", {}).get("count", 0),
        )
        for row in rows
    )
    return counts
315
+
316
+
317
def get_nickname_and_user_id(api_endpoint=DEFAULT_ENDPOINT, headers=None):
    """Fetch the current user's nickname and user ID.

    :param api_endpoint: service endpoint handed to ``get_userinfo``
    :param headers: optional headers passed through to ``get_userinfo``
    :return: ``(nick, uid)`` read from ``user_info``; each is ``""`` when absent
    :raises RemoteServiceError: if ``get_userinfo`` reports failure
    """
    ok, response = get_userinfo(api_endpoint, headers)
    if not ok:
        raise RemoteServiceError(response)

    nickname = response.get("user_info", {}).get("nick", "")
    user_id = response.get("user_info", {}).get("uid", "")
    return nickname, user_id
322
+
323
+
324
def get_tenant_id(api_endpoint=DEFAULT_ENDPOINT, headers=None):
    """Fetch the tenant ID reported for the current user's role.

    :param api_endpoint: service endpoint handed to ``get_current_user_role``
    :param headers: optional headers passed through to ``get_current_user_role``
    :return: ``data.tenant_id`` from the response, or ``""`` when absent
    :raises RemoteServiceError: if ``get_current_user_role`` reports failure
    """
    ok, response = get_current_user_role(api_endpoint, headers)
    if ok:
        role_data = response.get("data", {})
        return role_data.get("tenant_id", "")
    raise RemoteServiceError(response)
329
+
330
+
331
def add_data_pipeline_signature_headers(headers, secret_key, hostname):
    """Add a signed ``trpc-trans-info`` entry to ``headers`` in place.

    The signed message is
    ``caller_name=<hostname>&cmd=2&device_id=<hostname>&nonce=<nonce>&timestamp=<timestamp>``,
    authenticated with HMAC-SHA256 keyed by ``secret_key``. The header value
    is a JSON object holding the signature, nonce and timestamp.

    :param headers: mutable mapping of request headers; modified in place
    :param secret_key: string key for the HMAC
    :param hostname: value used for both ``caller_name`` and ``device_id``
    :return: ``None``
    """
    nonce = str(time.time_ns()) + str(int(time.time()))
    timestamp = str(int(time.time()))

    message = (
        f"caller_name={hostname}&cmd=2&device_id={hostname}"
        f"&nonce={nonce}&timestamp={timestamp}"
    )
    digest = hmac.new(secret_key.encode(), message.encode(), hashlib.sha256).digest()
    signature = base64.b64encode(digest)

    def _b64_text(raw):
        # Every trans-info value gets one base64 layer; the signature is
        # already base64, so it ends up encoded twice.
        return base64.b64encode(raw).decode()

    trans_info = {
        "X-Data-Pipeline-Signature": _b64_text(signature),
        "X-Data-Pipeline-Nonce": _b64_text(nonce.encode()),
        "X-Data-Pipeline-Timestamp": _b64_text(timestamp.encode()),
    }
    headers["trpc-trans-info"] = json.dumps(trans_info)
352
+
353
+
354
def send_data_with_signature(
    status, hostname, secret_key, task_id, task_compiler_id,
    api_endpoint=DEFAULT_ENDPOINT, headers=None
):
    """Send a signed collecting-status event through ``send_data``.

    The signature header is added by ``add_data_pipeline_signature_headers``.
    When the caller supplies ``headers``, that same dict is modified in place.

    :param status: value sent as ``collecting_status.status``
    :param hostname: used as caller name, device ID and device token
    :param secret_key: string key used to sign the request
    :param task_id: value sent as ``collecting_status.task_id``
    :param task_compiler_id: value sent as ``collecting_status.task_compiler_id``
    :param api_endpoint: service endpoint handed to ``send_data``
    :param headers: optional request headers; a new dict is used when ``None``
    :return: the result returned by ``send_data``
    :raises RemoteServiceError: if ``send_data`` reports failure
    """
    signed_headers = {} if headers is None else headers
    add_data_pipeline_signature_headers(signed_headers, secret_key, hostname)

    collecting_status = {
        "task_id": task_id,
        "task_compiler_id": task_compiler_id,
        "step": 2,
        "status": status,
    }
    request = {
        # Same cmd value as the "cmd=2" term in the signed message.
        "cmd": 2,
        "caller": {"name": hostname},
        "event": {
            "event_type": 1,
            "collecting_status": collecting_status,
        },
        "device_auth": {
            "device_id": hostname,
            "device_token": hostname,
        },
        "client_ts_ms": int(time.time() * 1000),
    }

    ok, response = send_data(request, api_endpoint, signed_headers)
    if ok:
        return response
    raise RemoteServiceError(response)
387
+
388
+
389
def get_download_url_by_data_id(id, api_endpoint=DEFAULT_ENDPOINT, headers=None):
    """Fetch the download URL for a data record.

    :param id: data record ID, sent to the service as ``id``
    :param api_endpoint: service endpoint handed to ``get_download_url``
    :param headers: optional headers passed through to ``get_download_url``
    :return: the non-empty ``data`` entry of the response
    :raises RemoteServiceError: if ``get_download_url`` reports failure
    :raises ValueError: if ``data`` is absent or empty
    """
    ok, response = get_download_url({"id": id}, api_endpoint, headers)
    if not ok:
        raise RemoteServiceError(response)

    download_url = response.get("data", "")
    if download_url:
        return download_url
    raise ValueError("获取的下载连接为空")
402
+
403
if __name__ == "__main__":
    from pprint import pprint

    # Manual smoke run: fetch one hard-coded collection task from the
    # default endpoint and pretty-print it.
    pprint(get_collection_task_by_id("b6e11e9b-84c1-4984-bbb8-ae6be21d002e"))