hirundo 0.1.21__py3-none-any.whl → 0.2.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hirundo/dataset_qa.py CHANGED
@@ -1,11 +1,9 @@
1
1
  import datetime
2
- import json
3
2
  import typing
4
3
  from collections.abc import AsyncGenerator, Generator
5
4
  from enum import Enum
6
5
  from typing import overload
7
6
 
8
- import httpx
9
7
  from pydantic import BaseModel, Field, model_validator
10
8
  from tqdm import tqdm
11
9
  from tqdm.contrib.logging import logging_redirect_tqdm
@@ -14,7 +12,16 @@ from hirundo._constraints import validate_labeling_info, validate_url
14
12
  from hirundo._env import API_HOST
15
13
  from hirundo._headers import get_headers
16
14
  from hirundo._http import raise_for_status_with_reason, requests
17
- from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
15
+ from hirundo._run_checking import (
16
+ STATUS_TO_PROGRESS_MAP,
17
+ RunStatus,
18
+ aiter_run_events,
19
+ build_status_text_map,
20
+ get_state,
21
+ handle_run_failure,
22
+ iter_run_events,
23
+ update_progress_from_result,
24
+ )
18
25
  from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
19
26
  from hirundo._urls import HirundoUrl
20
27
  from hirundo.dataset_enum import DatasetMetadataType, LabelingType
@@ -35,87 +42,57 @@ class HirundoError(Exception):
35
42
  pass
36
43
 
37
44
 
38
- MAX_RETRIES = 200 # Max 200 retries for HTTP SSE connection
39
-
40
-
41
- class RunStatus(Enum):
42
- PENDING = "PENDING"
43
- STARTED = "STARTED"
44
- SUCCESS = "SUCCESS"
45
- FAILURE = "FAILURE"
46
- AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
47
- REVOKED = "REVOKED"
48
- REJECTED = "REJECTED"
49
- RETRY = "RETRY"
50
-
51
-
52
- STATUS_TO_TEXT_MAP = {
53
- RunStatus.STARTED.value: "Dataset QA run in progress. Downloading dataset",
54
- RunStatus.PENDING.value: "Dataset QA run queued and not yet started",
55
- RunStatus.SUCCESS.value: "Dataset QA run completed successfully",
56
- RunStatus.FAILURE.value: "Dataset QA run failed",
57
- RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
58
- RunStatus.RETRY.value: "Dataset QA run failed. Retrying",
59
- RunStatus.REVOKED.value: "Dataset QA run was cancelled",
60
- RunStatus.REJECTED.value: "Dataset QA run was rejected",
61
- }
62
- STATUS_TO_PROGRESS_MAP = {
63
- RunStatus.STARTED.value: 0.0,
64
- RunStatus.PENDING.value: 0.0,
65
- RunStatus.SUCCESS.value: 100.0,
66
- RunStatus.FAILURE.value: 100.0,
67
- RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
68
- RunStatus.RETRY.value: 0.0,
69
- RunStatus.REVOKED.value: 100.0,
70
- RunStatus.REJECTED.value: 0.0,
71
- }
45
+ STATUS_TO_TEXT_MAP = build_status_text_map(
46
+ "Dataset QA",
47
+ started_detail="Dataset QA run in progress. Downloading dataset",
48
+ )
72
49
 
73
50
 
74
51
  class ClassificationRunArgs(BaseModel):
75
- image_size: typing.Optional[tuple[int, int]] = (224, 224)
52
+ image_size: tuple[int, int] | None = (224, 224)
76
53
  """
77
54
  Size (width, height) to which to resize classification images.
78
55
  It is recommended to keep this value at (224, 224) unless your classes are differentiated by very small differences.
79
56
  """
80
- upsample: typing.Optional[bool] = False
57
+ upsample: bool | None = False
81
58
  """
82
59
  Whether to upsample the dataset to attempt to balance the classes.
83
60
  """
84
61
 
85
62
 
86
63
  class ObjectDetectionRunArgs(ClassificationRunArgs):
87
- min_abs_bbox_size: typing.Optional[int] = None
64
+ min_abs_bbox_size: int | None = None
88
65
  """
89
66
  Minimum valid size (in pixels) of a bounding box to keep it in the dataset for QA.
90
67
  """
91
- min_abs_bbox_area: typing.Optional[int] = None
68
+ min_abs_bbox_area: int | None = None
92
69
  """
93
70
  Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for QA.
94
71
  """
95
- min_rel_bbox_size: typing.Optional[float] = None
72
+ min_rel_bbox_size: float | None = None
96
73
  """
97
74
  Minimum valid size (as a fraction of both image height and width) for a bounding box
98
75
  to keep it in the dataset for QA, relative to the corresponding dimension size,
99
76
  i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
100
77
  value is 0.06 (since both width and height are checked).
101
78
  """
102
- min_rel_bbox_area: typing.Optional[float] = None
79
+ min_rel_bbox_area: float | None = None
103
80
  """
104
81
  Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for QA.
105
82
  """
106
- crop_ratio: typing.Optional[float] = None
83
+ crop_ratio: float | None = None
107
84
  """
108
85
  Ratio of the bounding box to crop.
109
86
  Change this value at your own risk. It is recommended to keep it at 1.0 unless you know what you are doing.
110
87
  """
111
- add_mask_channel: typing.Optional[bool] = None
88
+ add_mask_channel: bool | None = None
112
89
  """
113
90
  Whether to add a mask channel to the image.
114
91
  Change at your own risk. It is recommended to keep it at False unless you know what you are doing.
115
92
  """
116
93
 
117
94
 
118
- RunArgs = typing.Union[ClassificationRunArgs, ObjectDetectionRunArgs]
95
+ RunArgs = ClassificationRunArgs | ObjectDetectionRunArgs
119
96
 
120
97
 
121
98
  class AugmentationName(str, Enum):
@@ -128,32 +105,32 @@ class AugmentationName(str, Enum):
128
105
  GAUSSIAN_BLUR = "GaussianBlur"
129
106
 
130
107
 
131
- class Domain(str, Enum):
108
+ class ModalityType(str, Enum):
132
109
  RADAR = "RADAR"
133
110
  VISION = "VISION"
134
111
  SPEECH = "SPEECH"
135
112
  TABULAR = "TABULAR"
136
113
 
137
114
 
138
- DOMAIN_TO_SUPPORTED_LABELING_TYPES = {
139
- Domain.RADAR: [
115
+ MODALITY_TO_SUPPORTED_LABELING_TYPES = {
116
+ ModalityType.RADAR: [
140
117
  LabelingType.SINGLE_LABEL_CLASSIFICATION,
141
118
  LabelingType.OBJECT_DETECTION,
142
119
  ],
143
- Domain.VISION: [
120
+ ModalityType.VISION: [
144
121
  LabelingType.SINGLE_LABEL_CLASSIFICATION,
145
122
  LabelingType.OBJECT_DETECTION,
146
123
  LabelingType.OBJECT_SEGMENTATION,
147
124
  LabelingType.SEMANTIC_SEGMENTATION,
148
125
  LabelingType.PANOPTIC_SEGMENTATION,
149
126
  ],
150
- Domain.SPEECH: [LabelingType.SPEECH_TO_TEXT],
151
- Domain.TABULAR: [LabelingType.SINGLE_LABEL_CLASSIFICATION],
127
+ ModalityType.SPEECH: [LabelingType.SPEECH_TO_TEXT],
128
+ ModalityType.TABULAR: [LabelingType.SINGLE_LABEL_CLASSIFICATION],
152
129
  }
153
130
 
154
131
 
155
132
  class QADataset(BaseModel):
156
- id: typing.Optional[int] = Field(default=None)
133
+ id: int | None = Field(default=None)
157
134
  """
158
135
  The ID of the dataset created on the server.
159
136
  """
@@ -169,17 +146,15 @@ class QADataset(BaseModel):
169
146
  - `LabelingType.OBJECT_DETECTION`: Indicates that the dataset is for object detection tasks
170
147
  - `LabelingType.SPEECH_TO_TEXT`: Indicates that the dataset is for speech-to-text tasks
171
148
  """
172
- language: typing.Optional[str] = None
149
+ language: str | None = None
173
150
  """
174
151
  Language of the Speech-to-Text audio dataset. This is required for Speech-to-Text datasets.
175
152
  """
176
- storage_config_id: typing.Optional[int] = None
153
+ storage_config_id: int | None = None
177
154
  """
178
155
  The ID of the storage config used to store the dataset and metadata.
179
156
  """
180
- storage_config: typing.Optional[
181
- typing.Union[StorageConfig, ResponseStorageConfig]
182
- ] = None
157
+ storage_config: StorageConfig | ResponseStorageConfig | None = None
183
158
  """
184
159
  The `StorageConfig` instance to link to.
185
160
  """
@@ -193,41 +168,44 @@ class QADataset(BaseModel):
193
168
  Note: All CSV `image_path` entries in the metadata file should be relative to this folder.
194
169
  """
195
170
 
196
- classes: typing.Optional[list[str]] = None
171
+ classes: list[str] | None = None
197
172
  """
198
173
  A full list of possible classes used in classification / object detection.
199
174
  It is currently required for clarity and performance.
200
175
  """
201
- labeling_info: typing.Union[LabelingInfo, list[LabelingInfo]]
176
+ labeling_info: LabelingInfo | list[LabelingInfo]
202
177
 
203
- augmentations: typing.Optional[list[AugmentationName]] = None
178
+ augmentations: list[AugmentationName] | None = None
204
179
  """
205
180
  Used to define which augmentations are apply to a vision dataset.
206
181
  For audio datasets, this field is ignored.
207
182
  If no value is provided, all augmentations are applied to vision datasets.
208
183
  """
209
- domain: Domain = Domain.VISION
184
+ modality: ModalityType = ModalityType.VISION
210
185
  """
211
- Used to define the domain of the dataset.
186
+ Used to define the modality of the dataset.
212
187
  Defaults to Image.
213
188
  """
214
189
 
215
- run_id: typing.Optional[str] = Field(default=None, init=False)
190
+ run_id: str | None = Field(default=None, init=False)
216
191
  """
217
192
  The ID of the Dataset QA run created on the server.
218
193
  """
219
194
 
220
- status: typing.Optional[RunStatus] = None
195
+ status: RunStatus | None = None
221
196
 
222
197
  @model_validator(mode="after")
223
198
  def validate_dataset(self):
224
- if self.domain not in DOMAIN_TO_SUPPORTED_LABELING_TYPES:
199
+ if self.modality not in MODALITY_TO_SUPPORTED_LABELING_TYPES:
225
200
  raise ValueError(
226
- f"Domain {self.domain} is not supported. Supported domains are: {list(DOMAIN_TO_SUPPORTED_LABELING_TYPES.keys())}"
201
+ f"Modality {self.modality} is not supported. Supported modalities are: {list(MODALITY_TO_SUPPORTED_LABELING_TYPES.keys())}"
227
202
  )
228
- if self.labeling_type not in DOMAIN_TO_SUPPORTED_LABELING_TYPES[self.domain]:
203
+ if (
204
+ self.labeling_type
205
+ not in MODALITY_TO_SUPPORTED_LABELING_TYPES[self.modality]
206
+ ):
229
207
  raise ValueError(
230
- f"Labeling type {self.labeling_type} is not supported for domain {self.domain}. Supported labeling types are: {DOMAIN_TO_SUPPORTED_LABELING_TYPES[self.domain]}"
208
+ f"Labeling type {self.labeling_type} is not supported for modality {self.modality}. Supported labeling types are: {MODALITY_TO_SUPPORTED_LABELING_TYPES[self.modality]}"
231
209
  )
232
210
  if self.storage_config is None and self.storage_config_id is None:
233
211
  raise ValueError(
@@ -307,7 +285,7 @@ class QADataset(BaseModel):
307
285
 
308
286
  @staticmethod
309
287
  def list_datasets(
310
- organization_id: typing.Optional[int] = None,
288
+ organization_id: int | None = None,
311
289
  ) -> list["QADatasetOut"]:
312
290
  """
313
291
  Lists all the datasets created by user's default organization
@@ -333,7 +311,8 @@ class QADataset(BaseModel):
333
311
 
334
312
  @staticmethod
335
313
  def list_runs(
336
- organization_id: typing.Optional[int] = None,
314
+ organization_id: int | None = None,
315
+ archived: bool | None = False,
337
316
  ) -> list["DataQARunOut"]:
338
317
  """
339
318
  Lists all the `QADataset` instances created by user's default organization
@@ -342,10 +321,11 @@ class QADataset(BaseModel):
342
321
 
343
322
  Args:
344
323
  organization_id: The ID of the organization to list the datasets for.
324
+ archived: Whether to list archived runs.
345
325
  """
346
326
  response = requests.get(
347
327
  f"{API_HOST}/dataset-qa/run/list",
348
- params={"dataset_organization_id": organization_id},
328
+ params={"dataset_organization_id": organization_id, "archived": archived},
349
329
  headers=get_headers(),
350
330
  timeout=READ_TIMEOUT,
351
331
  )
@@ -396,7 +376,7 @@ class QADataset(BaseModel):
396
376
 
397
377
  def create(
398
378
  self,
399
- organization_id: typing.Optional[int] = None,
379
+ organization_id: int | None = None,
400
380
  replace_if_exists: bool = False,
401
381
  ) -> int:
402
382
  """
@@ -453,8 +433,8 @@ class QADataset(BaseModel):
453
433
  @staticmethod
454
434
  def launch_qa_run(
455
435
  dataset_id: int,
456
- organization_id: typing.Optional[int] = None,
457
- run_args: typing.Optional[RunArgs] = None,
436
+ organization_id: int | None = None,
437
+ run_args: RunArgs | None = None,
458
438
  ) -> str:
459
439
  """
460
440
  Run the dataset QA process on the server using the dataset with the given ID
@@ -503,9 +483,9 @@ class QADataset(BaseModel):
503
483
 
504
484
  def run_qa(
505
485
  self,
506
- organization_id: typing.Optional[int] = None,
486
+ organization_id: int | None = None,
507
487
  replace_dataset_if_exists: bool = False,
508
- run_args: typing.Optional[RunArgs] = None,
488
+ run_args: RunArgs | None = None,
509
489
  ) -> str:
510
490
  """
511
491
  If the dataset was not created on the server yet, it is created.
@@ -556,53 +536,20 @@ class QADataset(BaseModel):
556
536
 
557
537
  @staticmethod
558
538
  def _check_run_by_id(run_id: str, retry=0) -> Generator[dict, None, None]:
559
- if retry > MAX_RETRIES:
560
- raise HirundoError("Max retries reached")
561
- last_event = None
562
- with httpx.Client(timeout=httpx.Timeout(None, connect=5.0)) as client:
563
- for sse in iter_sse_retrying(
564
- client,
565
- "GET",
566
- f"{API_HOST}/dataset-qa/run/{run_id}",
567
- headers=get_headers(),
568
- ):
569
- if sse.event == "ping":
570
- continue
571
- logger.debug(
572
- "[SYNC] received event: %s with data: %s and ID: %s and retry: %s",
573
- sse.event,
574
- sse.data,
575
- sse.id,
576
- sse.retry,
577
- )
578
- last_event = json.loads(sse.data)
579
- if not last_event:
580
- continue
581
- if "data" in last_event:
582
- data = last_event["data"]
583
- else:
584
- if "detail" in last_event:
585
- raise HirundoError(last_event["detail"])
586
- elif "reason" in last_event:
587
- raise HirundoError(last_event["reason"])
588
- else:
589
- raise HirundoError("Unknown error")
590
- yield data
591
- if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
592
- QADataset._check_run_by_id(run_id, retry + 1)
593
-
594
- @staticmethod
595
- def _handle_failure(iteration: dict):
596
- if iteration["result"]:
597
- raise HirundoError(f"QA run failed with error: {iteration['result']}")
598
- else:
599
- raise HirundoError("QA run failed with an unknown error in _handle_failure")
539
+ yield from iter_run_events(
540
+ f"{API_HOST}/dataset-qa/run/{run_id}",
541
+ headers=get_headers(),
542
+ retry=retry,
543
+ status_keys=("state",),
544
+ error_cls=HirundoError,
545
+ log=logger,
546
+ )
600
547
 
601
548
  @staticmethod
602
549
  @overload
603
550
  def check_run_by_id(
604
551
  run_id: str, stop_on_manual_approval: typing.Literal[True]
605
- ) -> typing.Optional[DatasetQAResults]: ...
552
+ ) -> DatasetQAResults | None: ...
606
553
 
607
554
  @staticmethod
608
555
  @overload
@@ -614,12 +561,12 @@ class QADataset(BaseModel):
614
561
  @overload
615
562
  def check_run_by_id(
616
563
  run_id: str, stop_on_manual_approval: bool
617
- ) -> typing.Optional[DatasetQAResults]: ...
564
+ ) -> DatasetQAResults | None: ...
618
565
 
619
566
  @staticmethod
620
567
  def check_run_by_id(
621
568
  run_id: str, stop_on_manual_approval: bool = False
622
- ) -> typing.Optional[DatasetQAResults]:
569
+ ) -> DatasetQAResults | None:
623
570
  """
624
571
  Check the status of a run given its ID
625
572
 
@@ -637,22 +584,25 @@ class QADataset(BaseModel):
637
584
  with logging_redirect_tqdm():
638
585
  t = tqdm(total=100.0)
639
586
  for iteration in QADataset._check_run_by_id(run_id):
640
- if iteration["state"] in STATUS_TO_PROGRESS_MAP:
641
- t.set_description(STATUS_TO_TEXT_MAP[iteration["state"]])
642
- t.n = STATUS_TO_PROGRESS_MAP[iteration["state"]]
587
+ state = get_state(iteration, ("state",))
588
+ if state in STATUS_TO_PROGRESS_MAP:
589
+ t.set_description(STATUS_TO_TEXT_MAP[state])
590
+ t.n = STATUS_TO_PROGRESS_MAP[state]
643
591
  logger.debug("Setting progress to %s", t.n)
644
592
  t.refresh()
645
- if iteration["state"] in [
593
+ if state in [
646
594
  RunStatus.FAILURE.value,
647
595
  RunStatus.REJECTED.value,
648
596
  RunStatus.REVOKED.value,
649
597
  ]:
650
598
  logger.error(
651
599
  "State is failure, rejected, or revoked: %s",
652
- iteration["state"],
600
+ state,
653
601
  )
654
- QADataset._handle_failure(iteration)
655
- elif iteration["state"] == RunStatus.SUCCESS.value:
602
+ handle_run_failure(
603
+ iteration, error_cls=HirundoError, run_label="QA"
604
+ )
605
+ elif state == RunStatus.SUCCESS.value:
656
606
  t.close()
657
607
  zip_temporary_url = iteration["result"]
658
608
  logger.debug("QA run completed. Downloading results")
@@ -662,45 +612,24 @@ class QADataset(BaseModel):
662
612
  zip_temporary_url,
663
613
  )
664
614
  elif (
665
- iteration["state"] == RunStatus.AWAITING_MANUAL_APPROVAL.value
615
+ state == RunStatus.AWAITING_MANUAL_APPROVAL.value
666
616
  and stop_on_manual_approval
667
617
  ):
668
618
  t.close()
669
619
  return None
670
- elif iteration["state"] is None:
671
- if (
672
- iteration["result"]
673
- and isinstance(iteration["result"], dict)
674
- and iteration["result"]["result"]
675
- and isinstance(iteration["result"]["result"], str)
676
- ):
677
- result_info = iteration["result"]["result"].split(":")
678
- if len(result_info) > 1:
679
- stage = result_info[0]
680
- current_progress_percentage = float(
681
- result_info[1].removeprefix(" ").removesuffix("% done")
682
- )
683
- elif len(result_info) == 1:
684
- stage = result_info[0]
685
- current_progress_percentage = t.n # Keep the same progress
686
- else:
687
- stage = "Unknown progress state"
688
- current_progress_percentage = t.n # Keep the same progress
689
- desc = (
690
- "QA run completed. Uploading results"
691
- if current_progress_percentage == 100.0
692
- else stage
693
- )
694
- t.set_description(desc)
695
- t.n = current_progress_percentage
696
- logger.debug("Setting progress to %s", t.n)
697
- t.refresh()
620
+ elif state is None:
621
+ update_progress_from_result(
622
+ iteration,
623
+ t,
624
+ uploading_text="QA run completed. Uploading results",
625
+ log=logger,
626
+ )
698
627
  raise HirundoError("QA run failed with an unknown error in check_run_by_id")
699
628
 
700
629
  @overload
701
630
  def check_run(
702
631
  self, stop_on_manual_approval: typing.Literal[True]
703
- ) -> typing.Optional[DatasetQAResults]: ...
632
+ ) -> DatasetQAResults | None: ...
704
633
 
705
634
  @overload
706
635
  def check_run(
@@ -709,7 +638,7 @@ class QADataset(BaseModel):
709
638
 
710
639
  def check_run(
711
640
  self, stop_on_manual_approval: bool = False
712
- ) -> typing.Optional[DatasetQAResults]:
641
+ ) -> DatasetQAResults | None:
713
642
  """
714
643
  Check the status of the current active instance's run.
715
644
 
@@ -741,32 +670,15 @@ class QADataset(BaseModel):
741
670
 
742
671
  """
743
672
  logger.debug("Checking run with ID: %s", run_id)
744
- if retry > MAX_RETRIES:
745
- raise HirundoError("Max retries reached")
746
- last_event = None
747
- async with httpx.AsyncClient(
748
- timeout=httpx.Timeout(None, connect=5.0)
749
- ) as client:
750
- async_iterator = await aiter_sse_retrying(
751
- client,
752
- "GET",
753
- f"{API_HOST}/dataset-qa/run/{run_id}",
754
- headers=get_headers(),
755
- )
756
- async for sse in async_iterator:
757
- if sse.event == "ping":
758
- continue
759
- logger.debug(
760
- "[ASYNC] Received event: %s with data: %s and ID: %s and retry: %s",
761
- sse.event,
762
- sse.data,
763
- sse.id,
764
- sse.retry,
765
- )
766
- last_event = json.loads(sse.data)
767
- yield last_event["data"]
768
- if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
769
- QADataset.acheck_run_by_id(run_id, retry + 1)
673
+ async for iteration in aiter_run_events(
674
+ f"{API_HOST}/dataset-qa/run/{run_id}",
675
+ headers=get_headers(),
676
+ retry=retry,
677
+ status_keys=("state",),
678
+ error_cls=HirundoError,
679
+ log=logger,
680
+ ):
681
+ yield iteration
770
682
 
771
683
  async def acheck_run(self) -> AsyncGenerator[dict, None]:
772
684
  """
@@ -776,6 +688,8 @@ class QADataset(BaseModel):
776
688
 
777
689
  This generator will produce values to show progress of the run.
778
690
 
691
+ Note: This function does not handle errors nor show progress. It is expected that you do that.
692
+
779
693
  Yields:
780
694
  Each event will be a dict, where:
781
695
  - `"state"` is PENDING, STARTED, RETRY, FAILURE or SUCCESS
@@ -846,11 +760,11 @@ class QADatasetOut(BaseModel):
846
760
 
847
761
  data_root_url: HirundoUrl
848
762
 
849
- classes: typing.Optional[list[str]] = None
850
- labeling_info: typing.Union[LabelingInfo, list[LabelingInfo]]
763
+ classes: list[str] | None = None
764
+ labeling_info: LabelingInfo | list[LabelingInfo]
851
765
 
852
- organization_id: typing.Optional[int]
853
- creator_id: typing.Optional[int]
766
+ organization_id: int | None
767
+ creator_id: int | None
854
768
  created_at: datetime.datetime
855
769
  updated_at: datetime.datetime
856
770
 
@@ -863,4 +777,6 @@ class DataQARunOut(BaseModel):
863
777
  status: RunStatus
864
778
  approved: bool
865
779
  created_at: datetime.datetime
866
- run_args: typing.Optional[RunArgs]
780
+ run_args: RunArgs | None
781
+
782
+ deleted_at: datetime.datetime | None = None
@@ -11,11 +11,11 @@ DataFrameType = TypeAliasType("DataFrameType", None)
11
11
  if has_pandas:
12
12
  from hirundo._dataframe import pd
13
13
 
14
- DataFrameType = TypeAliasType("DataFrameType", typing.Union[pd.DataFrame, None])
14
+ DataFrameType = TypeAliasType("DataFrameType", pd.DataFrame | None)
15
15
  if has_polars:
16
16
  from hirundo._dataframe import pl
17
17
 
18
- DataFrameType = TypeAliasType("DataFrameType", typing.Union[pl.DataFrame, None])
18
+ DataFrameType = TypeAliasType("DataFrameType", pl.DataFrame | None)
19
19
 
20
20
 
21
21
  T = typing.TypeVar("T")
@@ -32,7 +32,7 @@ class DatasetQAResults(BaseModel, typing.Generic[T]):
32
32
  """
33
33
  A polars/pandas DataFrame containing the results of the data QA run
34
34
  """
35
- object_suspects: typing.Optional[T]
35
+ object_suspects: T | None
36
36
  """
37
37
  A polars/pandas DataFrame containing the object-level results of the data QA run
38
38
  """
hirundo/git.py CHANGED
@@ -1,6 +1,5 @@
1
1
  import datetime
2
2
  import re
3
- import typing
4
3
 
5
4
  import pydantic
6
5
  from pydantic import BaseModel, field_validator
@@ -32,14 +31,14 @@ class GitSSHAuth(BaseModel):
32
31
  """
33
32
  The SSH key for the Git repository
34
33
  """
35
- ssh_password: typing.Optional[str]
34
+ ssh_password: str | None
36
35
  """
37
36
  The password for the SSH key for the Git repository.
38
37
  """
39
38
 
40
39
 
41
40
  class GitRepo(BaseModel):
42
- id: typing.Optional[int] = None
41
+ id: int | None = None
43
42
  """
44
43
  The ID of the Git repository.
45
44
  """
@@ -48,25 +47,25 @@ class GitRepo(BaseModel):
48
47
  """
49
48
  A name to identify the Git repository in the Hirundo system.
50
49
  """
51
- repository_url: typing.Union[str, RepoUrl]
50
+ repository_url: str | RepoUrl
52
51
  """
53
52
  The URL of the Git repository, it should start with `ssh://` or `https://` or be in the form `user@host:path`.
54
53
  If it is in the form `user@host:path`, it will be rewritten to `ssh://user@host/path`.
55
54
  """
56
- organization_id: typing.Optional[int] = None
55
+ organization_id: int | None = None
57
56
  """
58
57
  The ID of the organization that the Git repository belongs to.
59
58
  If not provided, it will be assigned to your default organization.
60
59
  """
61
60
 
62
- plain_auth: typing.Optional[GitPlainAuth] = pydantic.Field(
61
+ plain_auth: GitPlainAuth | None = pydantic.Field(
63
62
  default=None, examples=[None, {"username": "ben", "password": "password"}]
64
63
  )
65
64
  """
66
65
  The plain authentication details for the Git repository.
67
66
  Use this if using a special user with a username and password for authentication.
68
67
  """
69
- ssh_auth: typing.Optional[GitSSHAuth] = pydantic.Field(
68
+ ssh_auth: GitSSHAuth | None = pydantic.Field(
70
69
  default=None,
71
70
  examples=[
72
71
  {
@@ -84,7 +83,7 @@ class GitRepo(BaseModel):
84
83
 
85
84
  @field_validator("repository_url", mode="before", check_fields=True)
86
85
  @classmethod
87
- def check_valid_repository_url(cls, repository_url: typing.Union[str, RepoUrl]):
86
+ def check_valid_repository_url(cls, repository_url: str | RepoUrl):
88
87
  # Check if the URL has the `@` and `:` pattern with a non-numeric section before the next slash
89
88
  match = re.match("([^@]+@[^:]+):([^0-9/][^/]*)/(.+)", str(repository_url))
90
89
  if match: