recce-nightly 1.16.0.20250818__py3-none-any.whl → 1.16.0.20250819.post1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of recce-nightly might be problematic. See the package's advisory page on the registry for more details.

recce/state/cloud.py CHANGED
@@ -51,6 +51,7 @@ class CloudStateLoader(RecceStateLoader):
51
51
  cloud_options=cloud_options,
52
52
  initial_state=initial_state,
53
53
  )
54
+ self.recce_cloud = RecceCloud(token=self.token)
54
55
 
55
56
  def verify(self) -> bool:
56
57
  if self.catalog == "github":
@@ -74,157 +75,393 @@ class CloudStateLoader(RecceStateLoader):
74
75
  'Please provide a share URL in the command argument with option "--share-url <share-url>"'
75
76
  )
76
77
  return False
78
+ elif self.catalog == "snapshot":
79
+ if self.cloud_options.get("api_token") is None:
80
+ self.error_message = RECCE_API_TOKEN_MISSING.error_message
81
+ self.hint_message = RECCE_API_TOKEN_MISSING.hint_message
82
+ return False
83
+ if self.cloud_options.get("snapshot_id") is None:
84
+ self.error_message = "No snapshot ID is provided for the snapshot catalog."
85
+ self.hint_message = (
86
+ 'Please provide a snapshot ID in the command argument with option "--snapshot-id <snapshot-id>"'
87
+ )
88
+ return False
77
89
  return True
78
90
 
79
- def _load_state(self) -> Tuple[RecceState, str]:
80
- return self._load_state_from_cloud()
81
-
82
- def _export_state(self, state: RecceState = None) -> Union[str, None]:
83
- return self._export_state_to_cloud()
84
-
85
91
  def purge(self) -> bool:
86
92
  rc, err_msg = RecceCloudStateManager(self.cloud_options).purge_cloud_state()
87
93
  if err_msg:
88
94
  self.error_message = err_msg
89
95
  return rc
90
96
 
91
- def _load_state_from_cloud(self) -> Tuple[RecceState, str]:
97
+ def _load_state(self) -> Tuple[RecceState, str]:
92
98
  """
93
- Load the state from Recce Cloud.
99
+ Load the state from Recce Cloud based on catalog type.
94
100
 
95
101
  Returns:
96
102
  RecceState: The state object.
97
- str: The etag of the state file.
103
+ str: The etag of the state file (only used for GitHub).
98
104
  """
99
105
  if self.catalog == "github":
100
- if (self.pr_info is None) or (self.pr_info.id is None) or (self.pr_info.repository is None):
101
- raise RecceException("Cannot get the pull request information from GitHub.")
106
+ return self._load_state_from_github()
102
107
  elif self.catalog == "preview":
103
- if self.share_id is None:
104
- raise RecceException("Cannot load the share state from Recce Cloud. No share ID is provided.")
108
+ return self._load_state_from_preview()
109
+ elif self.catalog == "snapshot":
110
+ return self._load_state_from_snapshot(), None
111
+ else:
112
+ raise RecceException(f"Unsupported catalog type: {self.catalog}")
113
+
114
+ def _load_state_from_github(self) -> Tuple[RecceState, str]:
115
+ """Load state from GitHub PR with etag checking."""
116
+ if (self.pr_info is None) or (self.pr_info.id is None) or (self.pr_info.repository is None):
117
+ raise RecceException("Cannot get the pull request information from GitHub.")
105
118
 
106
- logger.debug("Fetching state from Recce Cloud...")
119
+ logger.debug("Fetching GitHub state from Recce Cloud...")
120
+
121
+ # Check metadata and etag for GitHub only
107
122
  metadata = self._get_metadata_from_recce_cloud()
108
- if metadata:
109
- state_etag = metadata.get("etag")
110
- else:
111
- state_etag = None
123
+ state_etag = metadata.get("etag") if metadata else None
124
+
125
+ # Return cached state if etag matches
112
126
  if self.state_etag and state_etag == self.state_etag:
113
127
  return self.state, self.state_etag
114
128
 
115
- return self._load_state_from_recce_cloud(), state_etag
129
+ # Download state from GitHub
130
+ presigned_url = self.recce_cloud.get_presigned_url_by_github_repo(
131
+ method=PresignedUrlMethod.DOWNLOAD,
132
+ pr_id=self.pr_info.id,
133
+ repository=self.pr_info.repository,
134
+ artifact_name=RECCE_STATE_COMPRESSED_FILE,
135
+ )
116
136
 
117
- def _get_metadata_from_recce_cloud(self) -> Union[dict, None]:
118
- recce_cloud = RecceCloud(token=self.token)
119
- return recce_cloud.get_artifact_metadata(pr_info=self.pr_info) if self.pr_info else None
137
+ password = self.cloud_options.get("password")
138
+ if password is None:
139
+ raise RecceException(RECCE_CLOUD_PASSWORD_MISSING.error_message)
120
140
 
121
- def _load_state_from_recce_cloud(self) -> Union[RecceState, None]:
122
- import tempfile
141
+ headers = s3_sse_c_headers(password)
142
+ loaded_state = self._download_state_from_url(presigned_url, SupportedFileTypes.GZIP, headers)
123
143
 
124
- import requests
144
+ # Handle the case where download returns None (404 error)
145
+ if loaded_state is None:
146
+ return None, state_etag
125
147
 
126
- recce_cloud = RecceCloud(token=self.token)
127
- password = None
148
+ return loaded_state, state_etag
128
149
 
129
- if self.catalog == "github":
130
- presigned_url = recce_cloud.get_presigned_url_by_github_repo(
131
- method=PresignedUrlMethod.DOWNLOAD,
132
- pr_id=self.pr_info.id,
133
- repository=self.pr_info.repository,
134
- artifact_name=RECCE_STATE_COMPRESSED_FILE,
135
- )
150
+ def _load_state_from_preview(self) -> Tuple[RecceState, None]:
151
+ """Load state from preview share (no etag checking needed)."""
152
+ if self.share_id is None:
153
+ raise RecceException("Cannot load the share state from Recce Cloud. No share ID is provided.")
136
154
 
137
- password = self.cloud_options.get("password")
138
- if password is None:
139
- raise RecceException(RECCE_CLOUD_PASSWORD_MISSING.error_message)
140
- elif self.catalog == "preview":
141
- share_id = self.cloud_options.get("share_id")
142
- presigned_url = recce_cloud.get_presigned_url_by_share_id(
143
- method=PresignedUrlMethod.DOWNLOAD, share_id=share_id
144
- )
155
+ logger.debug("Fetching preview state from Recce Cloud...")
145
156
 
146
- with tempfile.NamedTemporaryFile() as tmp:
147
- from .cloud import s3_sse_c_headers
157
+ # Download state from preview share
158
+ presigned_url = self.recce_cloud.get_presigned_url_by_share_id(
159
+ method=PresignedUrlMethod.DOWNLOAD, share_id=self.share_id
160
+ )
161
+
162
+ loaded_state = self._download_state_from_url(presigned_url, SupportedFileTypes.FILE)
163
+
164
+ # Handle the case where download returns None (404 error)
165
+ if loaded_state is None:
166
+ return None, None
167
+
168
+ return loaded_state, None
169
+
170
+ def _get_metadata_from_recce_cloud(self) -> Union[dict, None]:
171
+ return self.recce_cloud.get_artifact_metadata(pr_info=self.pr_info) if self.pr_info else None
148
172
 
149
- headers = s3_sse_c_headers(password) if password else None
173
+ def _download_state_from_url(
174
+ self, presigned_url: str, file_type: SupportedFileTypes, headers: dict = None
175
+ ) -> RecceState:
176
+ """Download state file from presigned URL and convert to RecceState."""
177
+ import tempfile
178
+
179
+ import requests
180
+
181
+ with tempfile.NamedTemporaryFile() as tmp:
150
182
  response = requests.get(presigned_url, headers=headers)
183
+
151
184
  if response.status_code == 404:
152
185
  self.error_message = "The state file is not found in Recce Cloud."
153
186
  return None
154
187
  elif response.status_code != 200:
155
188
  self.error_message = response.text
156
- raise RecceException(
157
- f"{response.status_code} Failed to download the state file from Recce Cloud. The password could be wrong."
158
- )
189
+ error_msg = f"{response.status_code} Failed to download the state file from Recce Cloud."
190
+ if headers: # GitHub case with password
191
+ error_msg += " The password could be wrong."
192
+ raise RecceException(error_msg)
193
+
159
194
  with open(tmp.name, "wb") as f:
160
195
  f.write(response.content)
161
196
 
162
- file_type = SupportedFileTypes.GZIP if self.catalog == "github" else SupportedFileTypes.FILE
163
197
  return RecceState.from_file(tmp.name, file_type=file_type)
164
198
 
165
- def _export_state_to_cloud(self) -> Tuple[Union[str, None], str]:
199
+ def _load_state_from_snapshot(self) -> RecceState:
200
+ """
201
+ Load state from snapshot by:
202
+ 1. Get snapshot info
203
+ 2. Download artifacts for both base and current snapshots
204
+ 3. Download recce_state if available, otherwise create empty state with artifacts
205
+ """
206
+ if self.snapshot_id is None:
207
+ raise RecceException("Cannot load the snapshot state from Recce Cloud. No snapshot ID is provided.")
208
+
209
+ # 1. Get snapshot information
210
+ logger.debug(f"Getting snapshot {self.snapshot_id}")
211
+ snapshot = self.recce_cloud.get_snapshot(self.snapshot_id)
212
+
213
+ org_id = snapshot.get("org_id")
214
+ project_id = snapshot.get("project_id")
215
+
216
+ if not org_id or not project_id:
217
+ raise RecceException(f"Snapshot {self.snapshot_id} does not belong to a valid organization or project.")
218
+
219
+ # 2. Download manifests and catalogs for both snapshots
220
+ logger.debug(f"Downloading current snapshot artifacts for {self.snapshot_id}")
221
+ current_artifacts = self._download_snapshot_artifacts(self.recce_cloud, org_id, project_id, self.snapshot_id)
222
+
223
+ logger.debug(f"Downloading base snapshot artifacts for project {project_id}")
224
+ base_artifacts = self._download_base_snapshot_artifacts(self.recce_cloud, org_id, project_id)
225
+
226
+ # 3. Try to download existing recce_state, otherwise create new state
227
+ try:
228
+ logger.debug(f"Downloading recce_state for snapshot {self.snapshot_id}")
229
+ state = self._download_snapshot_recce_state(self.recce_cloud, org_id, project_id, self.snapshot_id)
230
+ except Exception as e:
231
+ logger.debug(f"No existing recce_state found, creating new state: {e}")
232
+ state = RecceState()
233
+
234
+ # Set artifacts regardless of whether we loaded existing state
235
+ state.artifacts.base = base_artifacts
236
+ state.artifacts.current = current_artifacts
237
+
238
+ return state
239
+
240
+ def _download_snapshot_artifacts(self, recce_cloud, org_id: str, project_id: str, snapshot_id: str) -> dict:
241
+ """Download manifest and catalog for a snapshot, return JSON data directly."""
242
+ import requests
243
+
244
+ # Get download URLs
245
+ presigned_urls = recce_cloud.get_download_urls_by_snapshot_id(org_id, project_id, snapshot_id)
246
+
247
+ artifacts = {}
248
+
249
+ # Download manifest
250
+ response = requests.get(presigned_urls["manifest_url"])
251
+ if response.status_code == 200:
252
+ artifacts["manifest"] = response.json()
253
+ else:
254
+ raise RecceException(f"Failed to download manifest for snapshot {snapshot_id}")
255
+
256
+ # Download catalog
257
+ response = requests.get(presigned_urls["catalog_url"])
258
+ if response.status_code == 200:
259
+ artifacts["catalog"] = response.json()
260
+ else:
261
+ raise RecceException(f"Failed to download catalog for snapshot {snapshot_id}")
262
+
263
+ return artifacts
264
+
265
+ def _download_snapshot_recce_state(self, recce_cloud, org_id: str, project_id: str, snapshot_id: str) -> RecceState:
266
+ """Download recce_state for a snapshot."""
267
+ # Get download URLs (now includes recce_state_url)
268
+ presigned_urls = recce_cloud.get_download_urls_by_snapshot_id(org_id, project_id, snapshot_id)
269
+ recce_state_url = presigned_urls.get("recce_state_url")
270
+
271
+ if not recce_state_url:
272
+ raise RecceException(f"No recce_state_url found for snapshot {snapshot_id}")
273
+
274
+ # Reuse the existing download method
275
+ state = self._download_state_from_url(recce_state_url, SupportedFileTypes.FILE)
276
+
277
+ if state is None:
278
+ raise RecceException(f"Failed to download recce_state for snapshot {snapshot_id}")
279
+
280
+ return state
281
+
282
+ def _download_base_snapshot_artifacts(self, recce_cloud, org_id: str, project_id: str) -> dict:
283
+ """Download manifest and catalog for the base snapshot, return JSON data directly."""
284
+ import requests
285
+
286
+ # Get download URLs for base snapshot
287
+ presigned_urls = recce_cloud.get_base_snapshot_download_urls(org_id, project_id)
288
+
289
+ artifacts = {}
290
+
291
+ # Download manifest
292
+ response = requests.get(presigned_urls["manifest_url"])
293
+ if response.status_code == 200:
294
+ artifacts["manifest"] = response.json()
295
+ else:
296
+ raise RecceException(f"Failed to download base snapshot manifest for project {project_id}")
297
+
298
+ # Download catalog
299
+ response = requests.get(presigned_urls["catalog_url"])
300
+ if response.status_code == 200:
301
+ artifacts["catalog"] = response.json()
302
+ else:
303
+ raise RecceException(f"Failed to download base snapshot catalog for project {project_id}")
304
+
305
+ return artifacts
306
+
307
+ def _export_state(self) -> Tuple[Union[str, None], str]:
308
+ """
309
+ Export state to Recce Cloud based on catalog type.
310
+
311
+ Returns:
312
+ str: A message indicating the result of the export operation.
313
+ str: The etag of the exported state file (only used for GitHub).
314
+ """
315
+ logger.info("Store recce state to Recce Cloud")
316
+
166
317
  if self.catalog == "github":
167
- if (self.pr_info is None) or (self.pr_info.id is None) or (self.pr_info.repository is None):
168
- raise RecceException("Cannot get the pull request information from GitHub.")
318
+ return self._export_state_to_github()
169
319
  elif self.catalog == "preview":
170
- pass
320
+ return self._export_state_to_preview()
321
+ elif self.catalog == "snapshot":
322
+ return self._export_state_to_snapshot()
323
+ else:
324
+ raise RecceException(f"Unsupported catalog type: {self.catalog}")
171
325
 
326
+ def _export_state_to_github(self) -> Tuple[Union[str, None], str]:
327
+ """Export state to GitHub PR with metadata and etag."""
328
+ if (self.pr_info is None) or (self.pr_info.id is None) or (self.pr_info.repository is None):
329
+ raise RecceException("Cannot get the pull request information from GitHub.")
330
+
331
+ # Generate metadata for GitHub only
172
332
  check_status = CheckDAO().status()
173
333
  metadata = {
174
334
  "total_checks": check_status.get("total", 0),
175
335
  "approved_checks": check_status.get("approved", 0),
176
336
  }
177
337
 
178
- logger.info("Store recce state to Recce Cloud")
179
- message = self._export_state_to_recce_cloud(metadata=metadata)
180
- metadata = self._get_metadata_from_recce_cloud()
181
- state_etag = metadata.get("etag") if metadata else None
338
+ # Upload to Cloud
339
+ presigned_url = self.recce_cloud.get_presigned_url_by_github_repo(
340
+ method=PresignedUrlMethod.UPLOAD,
341
+ repository=self.pr_info.repository,
342
+ artifact_name=RECCE_STATE_COMPRESSED_FILE,
343
+ pr_id=self.pr_info.id,
344
+ metadata=metadata,
345
+ )
346
+ message = self._upload_state_to_url(
347
+ presigned_url=presigned_url,
348
+ file_type=SupportedFileTypes.GZIP,
349
+ password=self.cloud_options.get("password"),
350
+ metadata=metadata,
351
+ )
352
+
353
+ # Get updated etag after upload
354
+ metadata_response = self._get_metadata_from_recce_cloud()
355
+ state_etag = metadata_response.get("etag") if metadata_response else None
356
+
182
357
  if message:
183
358
  logger.warning(message)
184
359
  return message, state_etag
185
360
 
186
- def _export_state_to_recce_cloud(self, metadata: dict = None) -> Union[str, None]:
361
+ def _export_state_to_preview(self) -> Tuple[Union[str, None], None]:
362
+ """Export state to preview share (no metadata or etag needed)."""
363
+ share_id = self.cloud_options.get("share_id")
364
+ presigned_url = self.recce_cloud.get_presigned_url_by_share_id(
365
+ method=PresignedUrlMethod.UPLOAD,
366
+ share_id=share_id,
367
+ metadata=None,
368
+ )
369
+ message = self._upload_state_to_url(
370
+ presigned_url=presigned_url, file_type=SupportedFileTypes.FILE, password=None, metadata=None
371
+ )
372
+
373
+ if message:
374
+ logger.warning(message)
375
+ return message, None
376
+
377
+ def _export_state_to_snapshot(self) -> Tuple[Union[str, None], None]:
378
+ """Export state to snapshot (upload recce_state with empty artifacts)."""
379
+ if self.snapshot_id is None:
380
+ raise RecceException("Cannot export state to snapshot. No snapshot ID is provided.")
381
+
382
+ # Get snapshot information
383
+ snapshot = self.recce_cloud.get_snapshot(self.snapshot_id)
384
+ org_id = snapshot.get("org_id")
385
+ project_id = snapshot.get("project_id")
386
+
387
+ if not org_id or not project_id:
388
+ raise RecceException(f"Snapshot {self.snapshot_id} does not belong to a valid organization or project.")
389
+
390
+ # Get upload URLs (now includes recce_state_url)
391
+ presigned_urls = self.recce_cloud.get_upload_urls_by_snapshot_id(org_id, project_id, self.snapshot_id)
392
+ recce_state_url = presigned_urls.get("recce_state_url")
393
+
394
+ if not recce_state_url:
395
+ raise RecceException(f"No recce_state_url found for snapshot {self.snapshot_id}")
396
+
397
+ # Create a copy of the state with empty artifacts for upload
398
+ upload_state = RecceState()
399
+ upload_state.runs = self.state.runs.copy() if self.state.runs else []
400
+ upload_state.checks = self.state.checks.copy() if self.state.checks else []
401
+ # Keep artifacts empty (don't copy self.state.artifacts)
402
+
403
+ # Upload the state with empty artifacts
404
+ message = self._upload_state_to_url(
405
+ presigned_url=recce_state_url,
406
+ file_type=SupportedFileTypes.FILE,
407
+ password=None,
408
+ metadata=None,
409
+ state=upload_state,
410
+ )
411
+
412
+ if message:
413
+ logger.warning(message)
414
+ return message, None
415
+
416
+ def _upload_state_to_url(
417
+ self,
418
+ presigned_url: str,
419
+ file_type: SupportedFileTypes,
420
+ password: str = None,
421
+ metadata: dict = None,
422
+ state: RecceState = None,
423
+ ) -> Union[str, None]:
424
+ """Upload state file to presigned URL."""
187
425
  import tempfile
188
426
 
189
427
  import requests
190
428
 
191
- if self.catalog == "github":
192
- presigned_url = RecceCloud(token=self.token).get_presigned_url_by_github_repo(
193
- method=PresignedUrlMethod.UPLOAD,
194
- repository=self.pr_info.repository,
195
- artifact_name=RECCE_STATE_COMPRESSED_FILE,
196
- pr_id=self.pr_info.id,
197
- metadata=metadata,
198
- )
199
- elif self.catalog == "preview":
200
- share_id = self.cloud_options.get("share_id")
201
- presigned_url = RecceCloud(token=self.token).get_presigned_url_by_share_id(
202
- method=PresignedUrlMethod.UPLOAD,
203
- share_id=share_id,
204
- metadata=metadata,
205
- )
206
- compress_passwd = self.cloud_options.get("password")
207
- if compress_passwd:
208
- headers = s3_sse_c_headers(compress_passwd)
209
- else:
210
- headers = {}
429
+ # Use provided state or default to self.state
430
+ upload_state = state or self.state
211
431
 
432
+ # Prepare headers
433
+ headers = {}
434
+ if password:
435
+ headers.update(s3_sse_c_headers(password))
212
436
  if metadata:
213
437
  headers["x-amz-tagging"] = urlencode(metadata)
438
+
214
439
  with tempfile.NamedTemporaryFile() as tmp:
215
- if self.catalog == "github":
216
- file_type = SupportedFileTypes.GZIP
217
- elif self.catalog == "preview":
218
- file_type = SupportedFileTypes.FILE
219
- self._export_state_to_file(tmp.name, file_type=file_type)
440
+ # Use the specified state to export to file
441
+ json_data = upload_state.to_json()
442
+ io = file_io_factory(file_type)
443
+ io.write(tmp.name, json_data)
220
444
 
221
445
  with open(tmp.name, "rb") as fd:
222
446
  response = requests.put(presigned_url, data=fd.read(), headers=headers)
447
+
223
448
  if response.status_code not in [200, 204]:
224
449
  self.error_message = response.text
225
450
  return "Failed to upload the state file to Recce Cloud. Reason: " + response.text
451
+
226
452
  return None
227
453
 
454
+ def check_conflict(self) -> bool:
455
+ if self.catalog != "github":
456
+ return False
457
+
458
+ metadata = self._get_metadata_from_recce_cloud()
459
+ if not metadata:
460
+ return False
461
+
462
+ state_etag = metadata.get("etag")
463
+ return state_etag != self.state_etag
464
+
228
465
 
229
466
  class RecceCloudStateManager:
230
467
  error_message: str
@@ -35,8 +35,9 @@ class RecceStateLoader(ABC):
35
35
  self.state_lock = threading.Lock()
36
36
  self.state_etag = None
37
37
  self.pr_info = None
38
- self.catalog: Literal["github", "preview"] = "github"
38
+ self.catalog: Literal["github", "preview", "snapshot"] = "github"
39
39
  self.share_id = None
40
+ self.snapshot_id = None
40
41
 
41
42
  if self.cloud_mode:
42
43
  if self.cloud_options.get("github_token"):
@@ -47,11 +48,19 @@ class RecceStateLoader(ABC):
47
48
  if self.pr_info.id is None:
48
49
  raise RecceException("Cannot get the pull request information from GitHub.")
49
50
  elif self.cloud_options.get("api_token"):
50
- self.catalog = "preview"
51
- self.share_id = self.cloud_options.get("share_id")
51
+ if self.cloud_options.get("snapshot_id"):
52
+ self.catalog = "snapshot"
53
+ self.snapshot_id = self.cloud_options.get("snapshot_id")
54
+ else:
55
+ self.catalog = "preview"
56
+ self.share_id = self.cloud_options.get("share_id")
52
57
  else:
53
58
  raise RecceException(RECCE_CLOUD_TOKEN_MISSING.error_message)
54
59
 
60
+ @property
61
+ def token(self):
62
+ return self.cloud_options.get("github_token") or self.cloud_options.get("api_token")
63
+
55
64
  @abstractmethod
56
65
  def verify(self) -> bool:
57
66
  """
@@ -61,10 +70,6 @@ class RecceStateLoader(ABC):
61
70
  """
62
71
  raise NotImplementedError("Subclasses must implement this method.")
63
72
 
64
- @property
65
- def token(self):
66
- return self.cloud_options.get("github_token") or self.cloud_options.get("api_token")
67
-
68
73
  @property
69
74
  def error_and_hint(self) -> (Union[str, None], Union[str, None]):
70
75
  return self.error_message, self.hint_message
@@ -118,7 +123,7 @@ class RecceStateLoader(ABC):
118
123
  return message
119
124
 
120
125
  @abstractmethod
121
- def _export_state(self, state: RecceState = None) -> Tuple[Union[str, None], str]:
126
+ def _export_state(self) -> Tuple[Union[str, None], str]:
122
127
  """
123
128
  Export the current Recce state to a file or cloud storage.
124
129
  Returns:
@@ -143,17 +148,9 @@ class RecceStateLoader(ABC):
143
148
  return new_state
144
149
 
145
150
  def check_conflict(self) -> bool:
146
- if not self.cloud_mode:
147
- return False
148
-
149
- metadata = self._get_metadata_from_recce_cloud()
150
- if not metadata:
151
- return False
152
-
153
- state_etag = metadata.get("etag")
154
- return state_etag != self.state_etag
151
+ return False
155
152
 
156
- def info(self):
153
+ def info(self) -> dict:
157
154
  if self.state is None:
158
155
  self.error_message = "No state is loaded."
159
156
  return None
recce/util/lineage.py CHANGED
@@ -7,19 +7,17 @@ def find_upstream(node_ids: Iterable, parent_map):
7
7
  visited = set()
8
8
  upstream = set()
9
9
 
10
- def dfs(current):
10
+ stack = list(node_ids)
11
+ while stack:
12
+ current = stack.pop()
11
13
  if current in visited:
12
- return
14
+ continue
13
15
  visited.add(current)
14
-
15
16
  parents = parent_map.get(current, [])
16
17
  for parent in parents:
17
- upstream.add(parent)
18
- dfs(parent)
19
-
20
- for node_id in node_ids:
21
- dfs(node_id)
22
-
18
+ if parent not in upstream:
19
+ upstream.add(parent)
20
+ stack.append(parent)
23
21
  return upstream
24
22
 
25
23
 
@@ -27,19 +25,17 @@ def find_downstream(node_ids: Iterable, child_map):
27
25
  visited = set()
28
26
  downstream = set()
29
27
 
30
- def dfs(current):
28
+ stack = list(node_ids)
29
+ while stack:
30
+ current = stack.pop()
31
31
  if current in visited:
32
- return
32
+ continue
33
33
  visited.add(current)
34
-
35
34
  children = child_map.get(current, [])
36
35
  for child in children:
37
- downstream.add(child)
38
- dfs(child)
39
-
40
- for node_id in node_ids:
41
- dfs(node_id)
42
-
36
+ if child not in downstream:
37
+ downstream.add(child)
38
+ stack.append(child)
43
39
  return downstream
44
40
 
45
41