hirundo 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects each version exactly as it appears in its public registry.
- hirundo/__init__.py +1 -1
- hirundo/_env.py +19 -5
- hirundo/_headers.py +8 -4
- hirundo/_http.py +14 -0
- hirundo/_iter_sse_retrying.py +2 -2
- hirundo/cli.py +80 -17
- hirundo/dataset_optimization.py +233 -53
- hirundo/git.py +17 -15
- hirundo/logger.py +10 -0
- hirundo/storage.py +57 -22
- hirundo-0.1.8.dist-info/METADATA +176 -0
- hirundo-0.1.8.dist-info/RECORD +20 -0
- {hirundo-0.1.6.dist-info → hirundo-0.1.8.dist-info}/WHEEL +1 -1
- hirundo-0.1.6.dist-info/METADATA +0 -117
- hirundo-0.1.6.dist-info/RECORD +0 -18
- {hirundo-0.1.6.dist-info → hirundo-0.1.8.dist-info}/LICENSE +0 -0
- {hirundo-0.1.6.dist-info → hirundo-0.1.8.dist-info}/entry_points.txt +0 -0
- {hirundo-0.1.6.dist-info → hirundo-0.1.8.dist-info}/top_level.txt +0 -0
hirundo/dataset_optimization.py
CHANGED
```diff
@@ -1,22 +1,29 @@
 import json
-import
+import typing
 from collections.abc import AsyncGenerator, Generator
+from enum import Enum
 from io import StringIO
-from typing import
+from typing import overload
 
 import httpx
+import numpy as np
 import pandas as pd
 import requests
+from pandas._typing import DtypeArg
 from pydantic import BaseModel, Field, model_validator
+from tqdm import tqdm
+from tqdm.contrib.logging import logging_redirect_tqdm
 
 from hirundo._env import API_HOST
-from hirundo._headers import
+from hirundo._headers import get_auth_headers, json_headers
+from hirundo._http import raise_for_status_with_reason
 from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
 from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
 from hirundo.enum import DatasetMetadataType, LabellingType
+from hirundo.logger import get_logger
 from hirundo.storage import StorageIntegration, StorageLink
 
-logger =
+logger = get_logger(__name__)
 
 
 class HirundoError(Exception):
@@ -30,6 +37,66 @@ class HirundoError(Exception):
 MAX_RETRIES = 200 # Max 200 retries for HTTP SSE connection
 
 
+class RunStatus(Enum):
+    STARTED = "STARTED"
+    PENDING = "PENDING"
+    SUCCESS = "SUCCESS"
+    FAILURE = "FAILURE"
+    AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
+    RETRYING = "RETRYING"
+
+
+STATUS_TO_TEXT_MAP = {
+    RunStatus.STARTED.value: "Optimization run in progress. Downloading dataset",
+    RunStatus.PENDING.value: "Optimization run queued and not yet started",
+    RunStatus.SUCCESS.value: "Optimization run completed successfully",
+    RunStatus.FAILURE.value: "Optimization run failed",
+    RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
+    RunStatus.RETRYING.value: "Optimization run failed. Retrying",
+}
+STATUS_TO_PROGRESS_MAP = {
+    RunStatus.STARTED.value: 0.0,
+    RunStatus.PENDING.value: 0.0,
+    RunStatus.SUCCESS.value: 100.0,
+    RunStatus.FAILURE.value: 100.0,
+    RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
+    RunStatus.RETRYING.value: 0.0,
+}
+
+
+class DatasetOptimizationResults(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}
+
+    suspects: pd.DataFrame
+    """
+    A pandas DataFrame containing the results of the optimization run
+    """
+    warnings_and_errors: pd.DataFrame
+    """
+    A pandas DataFrame containing the warnings and errors of the optimization run
+    """
+
+
+CUSTOMER_INTERCHANGE_DTYPES: DtypeArg = {
+    "image_path": str,
+    "label_path": str,
+    "segments_mask_path": str,
+    "segment_id": np.int32,
+    "label": str,
+    "bbox_id": str,
+    "xmin": np.int32,
+    "ymin": np.int32,
+    "xmax": np.int32,
+    "ymax": np.int32,
+    "suspect_level": np.float32,  # If exists, must be one of the values in the enum below
+    "suggested_label": str,
+    "suggested_label_conf": np.float32,
+    "status": str,
+    # ⬆️ If exists, must be one of the following:
+    # NO_LABELS/MISSING_IMAGE/INVALID_IMAGE/INVALID_BBOX/INVALID_BBOX_SIZE/INVALID_SEG/INVALID_SEG_SIZE
+}
+
+
 class OptimizationDataset(BaseModel):
     name: str
     """
@@ -42,13 +109,13 @@ class OptimizationDataset(BaseModel):
     - `LabellingType.SingleLabelClassification`: Indicates that the dataset is for classification tasks
     - `LabellingType.ObjectDetection`: Indicates that the dataset is for object detection tasks
     """
-    dataset_storage:
+    dataset_storage: typing.Optional[StorageLink]
    """
     The storage link to the dataset. This can be a link to a file or a directory containing the dataset.
     If `None`, the `dataset_id` field must be set.
     """
 
-    classes: list[str]
+    classes: typing.Optional[list[str]] = None
     """
     A full list of possible classes used in classification / object detection.
     It is currently required for clarity and performance.
@@ -66,15 +133,15 @@ class OptimizationDataset(BaseModel):
     Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
     """
 
-    storage_integration_id:
+    storage_integration_id: typing.Optional[int] = Field(default=None, init=False)
     """
     The ID of the storage integration used to store the dataset and metadata.
     """
-    dataset_id:
+    dataset_id: typing.Optional[int] = Field(default=None, init=False)
     """
     The ID of the dataset created on the server.
     """
-    run_id:
+    run_id: typing.Optional[str] = Field(default=None, init=False)
     """
     The ID of the Dataset Optimization run created on the server.
     """
@@ -86,7 +153,7 @@ class OptimizationDataset(BaseModel):
         return self
 
     @staticmethod
-    def list(organization_id:
+    def list(organization_id: typing.Optional[int] = None) -> list[dict]:
         """
         Lists all the `OptimizationDataset` instances created by user's default organization
         or the `organization_id` passed
@@ -98,10 +165,10 @@ class OptimizationDataset(BaseModel):
         response = requests.get(
             f"{API_HOST}/dataset-optimization/dataset/",
             params={"dataset_organization_id": organization_id},
-            headers=
+            headers=get_auth_headers(),
             timeout=READ_TIMEOUT,
         )
-        response
+        raise_for_status_with_reason(response)
         return response.json()
 
     @staticmethod
@@ -114,10 +181,11 @@ class OptimizationDataset(BaseModel):
         """
         response = requests.delete(
             f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
-            headers=
+            headers=get_auth_headers(),
             timeout=MODIFY_TIMEOUT,
         )
-        response
+        raise_for_status_with_reason(response)
+        logger.info("Deleted dataset with ID: %s", dataset_id)
 
     def delete(self, storage_integration=True) -> None:
         """
@@ -167,14 +235,15 @@ class OptimizationDataset(BaseModel):
             },
             headers={
                 **json_headers,
-                **
+                **get_auth_headers(),
             },
             timeout=MODIFY_TIMEOUT,
         )
-        dataset_response
+        raise_for_status_with_reason(dataset_response)
         self.dataset_id = dataset_response.json()["id"]
         if not self.dataset_id:
             raise HirundoError("Failed to create the dataset")
+        logger.info("Created dataset with ID: %s", self.dataset_id)
         return self.dataset_id
 
     @staticmethod
@@ -191,10 +260,10 @@ class OptimizationDataset(BaseModel):
         """
         run_response = requests.post(
             f"{API_HOST}/dataset-optimization/run/{dataset_id}",
-            headers=
+            headers=get_auth_headers(),
             timeout=MODIFY_TIMEOUT,
         )
-        run_response
+        raise_for_status_with_reason(run_response)
         return run_response.json()["run_id"]
 
     def run_optimization(self) -> str:
@@ -210,6 +279,7 @@ class OptimizationDataset(BaseModel):
             self.dataset_id = self.create()
             run_id = self.launch_optimization_run(self.dataset_id)
             self.run_id = run_id
+            logger.info("Started the run with ID: %s", run_id)
             return run_id
         except requests.HTTPError as error:
             try:
@@ -237,30 +307,47 @@ class OptimizationDataset(BaseModel):
         self.run_id = None
 
     @staticmethod
-    def
-        if data["state"] == "SUCCESS":
-            data["result"] = pd.read_csv(StringIO(data["result"]))
-        else:
-            pass
-
-    @staticmethod
-    def check_run_by_id(run_id: str, retry=0) -> Generator[dict, None, None]:
+    def _clean_df_index(df: "pd.DataFrame") -> "pd.DataFrame":
         """
-
-
-        This generator will produce values to show progress of the run.
+        Clean the index of a dataframe in case it has unnamed columns.
 
         Args:
-
-            retry: A number used to track the number of retries to limit re-checks. *Do not* provide this value manually.
-
-        Yields:
-            Each event will be a dict, where:
-            - `"state"` is PENDING, STARTED, RETRY, FAILURE or SUCCESS
-            - `"result"` is a string describing the progress as a percentage for a PENDING state,
-            or the error for a FAILURE state or the results for a SUCCESS state
+            df (DataFrame): Dataframe to clean
 
+        Returns:
+            DataFrame: Cleaned dataframe
         """
+        index_cols = sorted(
+            [col for col in df.columns if col.startswith("Unnamed")], reverse=True
+        )
+        if len(index_cols) > 0:
+            df.set_index(index_cols.pop(), inplace=True)
+            df.rename_axis(index=None, columns=None, inplace=True)
+        if len(index_cols) > 0:
+            df.drop(columns=index_cols, inplace=True)
+
+        return df
+
+    @staticmethod
+    def _read_csvs_to_df(data: dict):
+        if data["state"] == RunStatus.SUCCESS.value:
+            data["result"]["suspects"] = OptimizationDataset._clean_df_index(
+                pd.read_csv(
+                    StringIO(data["result"]["suspects"]),
+                    dtype=CUSTOMER_INTERCHANGE_DTYPES,
+                )
+            )
+            data["result"]["warnings_and_errors"] = OptimizationDataset._clean_df_index(
+                pd.read_csv(
+                    StringIO(data["result"]["warnings_and_errors"]),
+                    dtype=CUSTOMER_INTERCHANGE_DTYPES,
+                )
+            )
+        else:
+            pass
+
+    @staticmethod
+    def _check_run_by_id(run_id: str, retry=0) -> Generator[dict, None, None]:
         if retry > MAX_RETRIES:
             raise HirundoError("Max retries reached")
         last_event = None
@@ -269,7 +356,7 @@ class OptimizationDataset(BaseModel):
             client,
             "GET",
             f"{API_HOST}/dataset-optimization/run/{run_id}",
-            headers=
+            headers=get_auth_headers(),
         ):
             if sse.event == "ping":
                 continue
@@ -284,26 +371,117 @@ class OptimizationDataset(BaseModel):
             if not last_event:
                 continue
             data = last_event["data"]
-            OptimizationDataset.
+            OptimizationDataset._read_csvs_to_df(data)
             yield data
-        if not last_event or last_event["data"]["state"] ==
-            OptimizationDataset.
+        if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
+            OptimizationDataset._check_run_by_id(run_id, retry + 1)
+
+    @staticmethod
+    @overload
+    def check_run_by_id(
+        run_id: str, stop_on_manual_approval: typing.Literal[True]
+    ) -> typing.Optional[DatasetOptimizationResults]: ...
 
-
+    @staticmethod
+    @overload
+    def check_run_by_id(
+        run_id: str, stop_on_manual_approval: typing.Literal[False] = False
+    ) -> DatasetOptimizationResults: ...
+
+    @staticmethod
+    @overload
+    def check_run_by_id(
+        run_id: str, stop_on_manual_approval: bool
+    ) -> typing.Optional[DatasetOptimizationResults]: ...
+
+    @staticmethod
+    def check_run_by_id(
+        run_id: str, stop_on_manual_approval: bool = False
+    ) -> typing.Optional[DatasetOptimizationResults]:
         """
-        Check the status of
+        Check the status of a run given its ID
 
-
+        Args:
+            run_id: The `run_id` produced by a `run_optimization` call
+            stop_on_manual_approval: If True, the function will return `None` if the run is awaiting manual approval
 
-
-
-
-
+        Returns:
+            A DatasetOptimizationResults object with the results of the optimization run
+
+        Raises:
+            HirundoError: If the maximum number of retries is reached or if the run fails
+        """
+        logger.debug("Checking run with ID: %s", run_id)
+        with logging_redirect_tqdm():
+            t = tqdm(total=100.0)
+            for iteration in OptimizationDataset._check_run_by_id(run_id):
+                if iteration["state"] in STATUS_TO_PROGRESS_MAP:
+                    t.set_description(STATUS_TO_TEXT_MAP[iteration["state"]])
+                    t.n = STATUS_TO_PROGRESS_MAP[iteration["state"]]
+                    logger.debug("Setting progress to %s", t.n)
+                    t.refresh()
+                if iteration["state"] == RunStatus.FAILURE.value:
+                    raise HirundoError(
+                        f"Optimization run failed with error: {iteration['result']}"
+                    )
+                elif iteration["state"] == RunStatus.SUCCESS.value:
+                    t.close()
+                    return DatasetOptimizationResults(
+                        suspects=iteration["result"]["suspects"],
+                        warnings_and_errors=iteration["result"][
+                            "warnings_and_errors"
+                        ],
+                    )
+                elif (
+                    iteration["state"] == RunStatus.AWAITING_MANUAL_APPROVAL.value
+                    and stop_on_manual_approval
+                ):
+                    t.close()
+                    return None
+                elif iteration["state"] is None:
+                    if (
+                        iteration["result"]
+                        and isinstance(iteration["result"], dict)
+                        and iteration["result"]["result"]
+                        and isinstance(iteration["result"]["result"], str)
+                    ):
+                        current_progress_percentage = float(
+                            iteration["result"]["result"].removesuffix("% done")
+                        )
+                        desc = (
+                            "Optimization run completed. Uploading results"
+                            if current_progress_percentage == 100.0
+                            else "Optimization run in progress"
+                        )
+                        t.set_description(desc)
+                        t.n = current_progress_percentage
+                        logger.debug("Setting progress to %s", t.n)
+                        t.refresh()
+        raise HirundoError("Optimization run failed with an unknown error")
+
+    @overload
+    def check_run(
+        self, stop_on_manual_approval: typing.Literal[True]
+    ) -> typing.Optional[DatasetOptimizationResults]: ...
+
+    @overload
+    def check_run(
+        self, stop_on_manual_approval: typing.Literal[False] = False
+    ) -> DatasetOptimizationResults: ...
+
+    def check_run(
+        self, stop_on_manual_approval: bool = False
+    ) -> typing.Optional[DatasetOptimizationResults]:
+        """
+        Check the status of the current active instance's run.
+
+        Returns:
+            A pandas DataFrame with the results of the optimization run
 
         """
         if not self.run_id:
             raise ValueError("No run has been started")
-        return self.check_run_by_id(self.run_id)
+        return self.check_run_by_id(self.run_id, stop_on_manual_approval)
 
     @staticmethod
     async def acheck_run_by_id(run_id: str, retry=0) -> AsyncGenerator[dict, None]:
@@ -324,6 +502,7 @@ class OptimizationDataset(BaseModel):
         - `"result"` is a string describing the progress as a percentage for a PENDING state, or the error for a FAILURE state or the results for a SUCCESS state
 
         """
+        logger.debug("Checking run with ID: %s", run_id)
         if retry > MAX_RETRIES:
             raise HirundoError("Max retries reached")
         last_event = None
@@ -334,7 +513,7 @@ class OptimizationDataset(BaseModel):
             client,
             "GET",
             f"{API_HOST}/dataset-optimization/run/{run_id}",
-            headers=
+            headers=get_auth_headers(),
         )
         async for sse in async_iterator:
             if sse.event == "ping":
@@ -348,7 +527,7 @@ class OptimizationDataset(BaseModel):
                 )
             last_event = json.loads(sse.data)
             yield last_event["data"]
-        if not last_event or last_event["data"]["state"] ==
+        if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
            OptimizationDataset.acheck_run_by_id(run_id, retry + 1)
 
     async def acheck_run(self) -> AsyncGenerator[dict, None]:
@@ -380,12 +559,13 @@ class OptimizationDataset(BaseModel):
         """
         if not run_id:
             raise ValueError("No run has been started")
+        logger.info("Cancelling run with ID: %s", run_id)
         response = requests.delete(
             f"{API_HOST}/dataset-optimization/run/{run_id}",
-            headers=
+            headers=get_auth_headers(),
             timeout=MODIFY_TIMEOUT,
         )
-        response
+        raise_for_status_with_reason(response)
 
     def cancel(self) -> None:
         """
```
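The headline change in `hirundo/dataset_optimization.py` is that `check_run` / `check_run_by_id` no longer yield raw SSE event dicts: they drive a `tqdm` progress bar internally and return a `DatasetOptimizationResults` model whose `suspects` and `warnings_and_errors` fields are pandas DataFrames parsed with `CUSTOMER_INTERCHANGE_DTYPES`. A minimal usage sketch of the new surface follows; the constructor arguments (`labelling_type` in particular) are assumptions inferred from the fields and docstrings visible above, not a verbatim example from the package docs:

```python
# Sketch of the 0.1.8 API shown in this diff. The dataset fields below
# are assumptions based on the visible docstrings; consult the package
# docs for exact construction.
from hirundo.dataset_optimization import OptimizationDataset

dataset = OptimizationDataset(
    name="my-dataset",
    labelling_type=...,      # assumed field; a LabellingType value
    dataset_storage=...,     # a StorageLink, or None when reusing dataset_id
    classes=["cat", "dog"],  # now typing.Optional and defaults to None
)

run_id = dataset.run_optimization()  # creates the dataset, starts the run

# Blocks while rendering a tqdm progress bar. Returns
# DatasetOptimizationResults on SUCCESS, raises HirundoError on FAILURE,
# and returns None when stop_on_manual_approval=True and the run reaches
# the "AWAITING MANUAL APPROVAL" state.
results = dataset.check_run(stop_on_manual_approval=True)
if results is not None:
    print(results.suspects.head())
    print(results.warnings_and_errors.head())
```

Note the `typing.overload` stack in the diff: with `stop_on_manual_approval=False` (the default) the return type narrows to `DatasetOptimizationResults`, so callers that never use manual approval need no `None` check.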
hirundo/git.py
CHANGED
```diff
@@ -1,6 +1,6 @@
-import logging
 import re
-
+import typing
+from typing import Annotated
 
 import pydantic
 import requests
@@ -8,10 +8,12 @@ from pydantic import BaseModel, field_validator
 from pydantic_core import Url
 
 from hirundo._env import API_HOST
-from hirundo._headers import
+from hirundo._headers import get_auth_headers, json_headers
+from hirundo._http import raise_for_status_with_reason
 from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
+from hirundo.logger import get_logger
 
-logger =
+logger = get_logger(__name__)
 
 
 class GitPlainAuthBase(BaseModel):
@@ -30,14 +32,14 @@ class GitSSHAuthBase(BaseModel):
     """
     The SSH key for the Git repository
     """
-    ssh_password:
+    ssh_password: typing.Optional[str]
     """
     The password for the SSH key for the Git repository.
     """
 
 
 class GitRepo(BaseModel):
-    id:
+    id: typing.Optional[int] = None
     """
     The ID of the Git repository.
     """
@@ -51,20 +53,20 @@ class GitRepo(BaseModel):
     The URL of the Git repository, it should start with `ssh://` or `https://` or be in the form `user@host:path`.
     If it is in the form `user@host:path`, it will be rewritten to `ssh://user@host:path`.
     """
-    organization_id:
+    organization_id: typing.Optional[int] = None
     """
     The ID of the organization that the Git repository belongs to.
     If not provided, it will be assigned to your default organization.
     """
 
-    plain_auth:
+    plain_auth: typing.Optional[GitPlainAuthBase] = pydantic.Field(
         default=None, examples=[None, {"username": "ben", "password": "password"}]
     )
     """
     The plain authentication details for the Git repository.
     Use this if using a special user with a username and password for authentication.
     """
-    ssh_auth:
+    ssh_auth: typing.Optional[GitSSHAuthBase] = pydantic.Field(
         default=None,
         examples=[
             {
@@ -108,11 +110,11 @@ class GitRepo(BaseModel):
             json=self.model_dump(),
             headers={
                 **json_headers,
-                **
+                **get_auth_headers(),
             },
             timeout=MODIFY_TIMEOUT,
         )
-        git_repo
+        raise_for_status_with_reason(git_repo)
         git_repo_id = git_repo.json()["id"]
         self.id = git_repo_id
         return git_repo_id
@@ -125,11 +127,11 @@ class GitRepo(BaseModel):
         git_repos = requests.get(
             f"{API_HOST}/git-repo/",
             headers={
-                **
+                **get_auth_headers(),
             },
             timeout=READ_TIMEOUT,
         )
-        git_repos
+        raise_for_status_with_reason(git_repos)
         return git_repos.json()
 
     @staticmethod
@@ -143,11 +145,11 @@ class GitRepo(BaseModel):
         git_repo = requests.delete(
             f"{API_HOST}/git-repo/{git_repo_id}",
             headers={
-                **
+                **get_auth_headers(),
             },
             timeout=MODIFY_TIMEOUT,
         )
-        git_repo
+        raise_for_status_with_reason(git_repo)
 
     def delete(self):
         """
```
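`git.py` adopts the same two patterns as `dataset_optimization.py`: implicit optional fields become explicit `typing.Optional[...]`, and the truncated `- response` / `- git_repo` call sites (presumably bare `.raise_for_status()` calls) are replaced by `raise_for_status_with_reason` from the new `hirundo/_http.py` (+14 lines, not shown in this section). Since its body isn't visible here, the following is only a plausible sketch of such a helper, assuming it surfaces the server-supplied error body alongside the status code:

```python
# Hypothetical sketch -- hirundo/_http.py is not shown in this diff, so
# this is an assumption about its behavior inferred from the call sites.
import requests


def raise_for_status_with_reason(response: requests.Response) -> None:
    """Like Response.raise_for_status(), but keep the server's error body."""
    try:
        response.raise_for_status()
    except requests.HTTPError as error:
        # Attach the response text so callers see why the request failed,
        # not just the status code.
        error.args = (*error.args, response.text)
        raise
```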
hirundo/logger.py
ADDED
```diff
@@ -0,0 +1,10 @@
+import logging
+import os
+
+
+def get_logger(name: str) -> logging.Logger:
+    logger = logging.getLogger(name)
+    log_level = os.getenv("LOG_LEVEL")
+    logger.setLevel(log_level if log_level else logging.INFO)
+    logger.addHandler(logging.StreamHandler())
+    return logger
```
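The new `hirundo/logger.py` centralizes logger construction: modules now call `get_logger(__name__)` instead of configuring `logging` themselves, and verbosity is controlled by the `LOG_LEVEL` environment variable (`Logger.setLevel` accepts standard level names such as `"DEBUG"` or `"WARNING"`; when unset, the level defaults to `INFO`). A short usage sketch:

```python
# LOG_LEVEL must be set before get_logger() is called, since the level
# is read once at construction time.
import os

os.environ["LOG_LEVEL"] = "DEBUG"

from hirundo.logger import get_logger

logger = get_logger(__name__)
logger.debug("SSE progress events are now visible")
```

One caveat of this design: `get_logger` attaches a fresh `StreamHandler` on every call, so calling it twice with the same name duplicates output lines; the package calls it once per module, where this is harmless.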