src-auth-perms-sync 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. src_auth_perms_sync/__init__.py +1 -0
  2. src_auth_perms_sync/__main__.py +6 -0
  3. src_auth_perms_sync/cli.py +646 -0
  4. src_auth_perms_sync/orgs/__init__.py +1 -0
  5. src_auth_perms_sync/orgs/command.py +7 -0
  6. src_auth_perms_sync/orgs/queries.py +44 -0
  7. src_auth_perms_sync/orgs/sync.py +1167 -0
  8. src_auth_perms_sync/orgs/types.py +103 -0
  9. src_auth_perms_sync/permissions/__init__.py +1 -0
  10. src_auth_perms_sync/permissions/apply.py +420 -0
  11. src_auth_perms_sync/permissions/command.py +918 -0
  12. src_auth_perms_sync/permissions/full_set.py +880 -0
  13. src_auth_perms_sync/permissions/mapping.py +627 -0
  14. src_auth_perms_sync/permissions/maps.py +291 -0
  15. src_auth_perms_sync/permissions/queries.py +180 -0
  16. src_auth_perms_sync/permissions/restore.py +913 -0
  17. src_auth_perms_sync/permissions/snapshot.py +1502 -0
  18. src_auth_perms_sync/permissions/sourcegraph.py +392 -0
  19. src_auth_perms_sync/permissions/types.py +116 -0
  20. src_auth_perms_sync/permissions/workflow.py +526 -0
  21. src_auth_perms_sync/shared/__init__.py +1 -0
  22. src_auth_perms_sync/shared/backups.py +119 -0
  23. src_auth_perms_sync/shared/id_codec.py +67 -0
  24. src_auth_perms_sync/shared/queries.py +65 -0
  25. src_auth_perms_sync/shared/run_context.py +34 -0
  26. src_auth_perms_sync/shared/saml_groups.py +267 -0
  27. src_auth_perms_sync/shared/site_config.py +366 -0
  28. src_auth_perms_sync/shared/sourcegraph.py +69 -0
  29. src_auth_perms_sync/shared/types.py +69 -0
  30. src_auth_perms_sync-0.2.1.dist-info/METADATA +256 -0
  31. src_auth_perms_sync-0.2.1.dist-info/RECORD +34 -0
  32. src_auth_perms_sync-0.2.1.dist-info/WHEEL +4 -0
  33. src_auth_perms_sync-0.2.1.dist-info/entry_points.txt +2 -0
  34. src_auth_perms_sync-0.2.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1167 @@
1
+ """Sourcegraph organization sync command handler."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import datetime
6
+ import json
7
+ import logging
8
+ import re
9
+ import time
10
+ from collections.abc import Iterable
11
+ from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+ from typing import Any, cast
15
+
16
+ import src_py_lib as src
17
+
18
+ from ..permissions import apply as permissions_apply
19
+ from ..shared import backups, run_context, saml_groups
20
+ from ..shared import sourcegraph as shared_sourcegraph
21
+ from ..shared import types as shared_types
22
+ from . import queries
23
+ from . import types as organization_types
24
+
25
log = logging.getLogger(__name__)

# --- GraphQL batching / paging limits -------------------------------------
# Number of org names resolved by one aliased lookup query.
ORGANIZATION_LOOKUP_BATCH_SIZE: int = 50
# `first:` argument sent with every `organizations(query: ...)` search alias.
ORGANIZATION_SEARCH_RESULT_LIMIT: int = 100
# Page size used when streaming the members of an existing org.
ORGANIZATION_MEMBER_PAGE_SIZE: int = 1000
# Longest generated org name accepted before the run aborts.
ORGANIZATION_NAME_MAX_LENGTH: int = 255

# --- Snapshot file format versions (presumably embedded in the JSON written
# by the snapshot helpers elsewhere in this module — confirm) ---------------
ORGANIZATION_SNAPSHOT_SCHEMA_VERSION: int = 1
ORGANIZATION_SNAPSHOT_DIFF_SCHEMA_VERSION: int = 1

# --- Org-name normalization -----------------------------------------------
# Any run of characters outside [A-Za-z0-9] becomes a single separator.
_ORGANIZATION_NAME_PART_RE = re.compile(r"[^A-Za-z0-9]+")
# Collapses repeated dashes into one.
_ORGANIZATION_NAME_DASH_RUN_RE = re.compile(r"-+")

# --- Server error-message fragments ---------------------------------------
# Matched against GraphQL error text to recognize benign conflicts.
_ALREADY_MEMBER_TEXT = "user is already a member of the organization"
_ORGANIZATION_EXISTS_TEXT = "organization name is already taken"
38
+
39
+
40
@dataclass(frozen=True)
class _OrganizationSyncState:
    """Immutable bundle describing one planned SAML organization sync.

    It carries the SAML-derived targets, the live org state read from the
    instance, the authenticated user running the sync, the computed plan, and
    the snapshots describing the state before and after that plan.
    """

    # Desired orgs keyed by generated org name.
    targets: dict[str, organization_types.TargetOrganization]
    # The `currentUser` returned by the lookup query (id + username).
    current_user: organization_types.OrgMember
    # Live org state keyed by org name; `id is None` means "not created yet".
    current_states: dict[str, organization_types.OrganizationState]
    # Snapshot of the live state before any mutation.
    before_snapshot: organization_types.OrganizationSnapshot
    # Creates / additions / removals computed from targets vs. live state.
    plan: organization_types.OrganizationPlan
    # Snapshot of the state the plan is expected to produce.
    expected_snapshot: organization_types.OrganizationSnapshot
50
+
51
+
52
@dataclass(frozen=True)
class _OrganizationApplyResult:
    """Per-phase mutation tallies produced by one organization sync apply."""

    # Outcome of the createOrganization phase.
    creates: shared_types.MutationCounts
    # Outcome of the add-member phase.
    additions: shared_types.MutationCounts
    # Outcome of the remove-member phase.
    removals: shared_types.MutationCounts
59
+
60
+
61
def _load_organization_sync_state(
    client: src.SourcegraphClient,
    saml_groups_attribute_name_by_config_id: dict[str, str],
    parallelism: int,
    command_event: dict[str, Any],
    command_data: run_context.CommandData,
    worker_pool: ThreadPoolExecutor | None = None,
) -> _OrganizationSyncState | None:
    """Build the complete sync state for SAML-backed organizations.

    Steps:
      1. Resolve auth providers. ``command_data.auth_providers`` is reused when
         it was already loaded; otherwise they are queried.
      2. Derive target orgs from SAML group memberships.
      3. Read the live org state for every target.
      4. Compute the plan and the before/expected snapshots.

    ``target_organizations`` and ``desired_memberships`` are recorded on
    ``command_event``.

    Returns:
        The loaded state. Returns None, after logging a warning, when there
        is no SAML auth provider or no SAML group membership to sync.
    """
    providers = command_data.auth_providers
    if providers is None:
        log.info("Querying auth providers from %s ...", client.endpoint)
        providers = shared_sourcegraph.list_auth_providers(client)
    else:
        log.info("Reusing %d auth provider(s) loaded earlier in this run.", len(providers))

    saml_provider_count = sum(
        1
        for provider in providers
        if provider["serviceType"] == saml_groups.SAML_SERVICE_TYPE
    )
    if saml_provider_count == 0:
        log.warning("No SAML auth providers found — nothing to sync.")
        return None
    attribute_names_by_provider = saml_groups.attribute_names_by_provider_key(
        providers, saml_groups_attribute_name_by_config_id
    )
    log.info("Received %d SAML auth provider(s).", saml_provider_count)

    # All providers are passed on; SAML filtering happens inside the collector.
    targets = _collect_target_organizations(
        providers,
        attribute_names_by_provider,
        client,
        command_data.saml_group_users,
    )
    desired_membership_count = sum(
        len(target.desired_members_by_id) for target in targets.values()
    )
    command_event["target_organizations"] = len(targets)
    command_event["desired_memberships"] = desired_membership_count
    if not targets:
        log.warning("No SAML group memberships found in user accountData — nothing to sync.")
        return None

    organization_names = sorted(targets)
    current_user, current_states = _load_current_organization_states(
        client, organization_names, parallelism, worker_pool
    )
    endpoint = client.endpoint
    before_snapshot = _snapshot_from_states(endpoint, targets, current_states)
    plan = _plan_organization_sync(targets, current_states, current_user)
    expected_snapshot = _snapshot_from_states(
        endpoint,
        targets,
        _expected_states_from_targets(targets, current_states),
    )
    return _OrganizationSyncState(
        targets=targets,
        current_user=current_user,
        current_states=current_states,
        before_snapshot=before_snapshot,
        plan=plan,
        expected_snapshot=expected_snapshot,
    )
121
+
122
+
123
def _log_organization_sync_plan(sync_state: _OrganizationSyncState) -> None:
    """Log the planned mutation counts followed by a current → desired diff."""
    plan = sync_state.plan
    create_count, add_count, remove_count = (
        len(plan[key]) for key in ("create_names", "additions", "removals")
    )
    log.info(
        "Organization sync plan: create %d org(s), add %d member(s), remove %d member(s).",
        create_count,
        add_count,
        remove_count,
    )
    rendered_diff = _render_organization_diff(
        sync_state.before_snapshot, sync_state.expected_snapshot
    )
    log.info("Diff (current → desired):\n%s", rendered_diff)
134
+
135
+
136
def _write_organization_snapshot_pair(
    timestamp: str,
    endpoint: str,
    command_name: str,
    before_snapshot: organization_types.OrganizationSnapshot,
    after_snapshot: organization_types.OrganizationSnapshot,
) -> tuple[Path, Path, Path]:
    """Write the before, after, and diff snapshot files for one command run.

    Returns:
        ``(before_path, after_path, diff_path)`` in that order.
    """
    paths_by_kind = {
        kind: _organization_snapshot_path(timestamp, endpoint, command_name, kind)
        for kind in ("before", "after", "diff")
    }
    _write_organization_snapshot(paths_by_kind["before"], before_snapshot)
    _write_organization_snapshot(paths_by_kind["after"], after_snapshot)
    _write_organization_snapshot_diff(paths_by_kind["diff"], before_snapshot, after_snapshot)
    return paths_by_kind["before"], paths_by_kind["after"], paths_by_kind["diff"]
150
+
151
+
152
def _finish_organization_dry_run(
    endpoint: str,
    timestamp: str,
    sync_state: _OrganizationSyncState,
    do_backup: bool,
) -> None:
    """Finish a dry run, optionally writing the would-be snapshots.

    With backups enabled, the current state is written as "before" and the
    planned result as "after". No mutation is performed either way.
    """
    if not do_backup:
        log.info("Skipped dry-run org snapshots because --no-backup was set.")
    else:
        written = _write_organization_snapshot_pair(
            timestamp,
            endpoint,
            "sync-saml-orgs-dry-run",
            sync_state.before_snapshot,
            sync_state.expected_snapshot,
        )
        log.info("Wrote dry-run org snapshots: before=%s after=%s diff=%s.", *written)
    log.info("Dry run complete. Pass --apply to mutate organization membership.")
175
+
176
+
177
def _write_organization_apply_before_snapshot(
    endpoint: str,
    timestamp: str,
    before_snapshot: organization_types.OrganizationSnapshot,
) -> Path:
    """Persist the pre-apply org snapshot and return the file it was written to."""
    snapshot_path = _organization_snapshot_path(timestamp, endpoint, "sync-saml-orgs-apply", "before")
    _write_organization_snapshot(snapshot_path, before_snapshot)
    log.info("Wrote before org snapshot: %s.", snapshot_path)
    return snapshot_path
191
+
192
+
193
def _record_organization_apply_event(
    command_event: dict[str, Any],
    result: _OrganizationApplyResult,
) -> None:
    """Copy per-phase succeeded/failed/canceled counts onto the command event.

    Keys follow ``<phase>_<outcome>``, where phase is one of create/add/remove.
    """
    phase_counts = (
        ("create", result.creates),
        ("add", result.additions),
        ("remove", result.removals),
    )
    for phase, counts in phase_counts:
        command_event[f"{phase}_succeeded"] = counts.succeeded
        command_event[f"{phase}_failed"] = counts.failed
        command_event[f"{phase}_canceled"] = counts.canceled
206
+
207
+
208
def _apply_organization_sync(
    client: src.SourcegraphClient,
    sync_state: _OrganizationSyncState,
    parallelism: int,
    command_event: dict[str, Any],
    worker_pool: ThreadPoolExecutor | None = None,
) -> _OrganizationApplyResult:
    """Apply creates first, then recompute and apply member changes.

    Successful creates update ``sync_state.current_states`` in place. The
    membership plan is therefore recomputed against the orgs that now exist
    before any additions or removals are applied.

    Raises:
        SystemExit: with status 1 if any create failed or was canceled. The
            partial counts are recorded on ``command_event`` first, and no
            membership change is attempted.
    """
    with src.stage("apply"):
        create_counts = _apply_create_organizations(
            client,
            sync_state.plan["create_names"],
            sync_state.current_states,
            sync_state.current_user,
            parallelism,
            worker_pool,
        )
        if create_counts.failed or create_counts.canceled:
            # Membership changes depend on every target org existing, so stop
            # here. Emit a summary like the membership-failure path does, so
            # the reason for the non-zero exit is visible in the log.
            log.error(
                "SAML org sync failed while creating orgs: %d create(s) failed, %d canceled. "
                "No membership changes were applied; review the log file, then re-run.",
                create_counts.failed,
                create_counts.canceled,
            )
            result = _OrganizationApplyResult(
                creates=create_counts,
                additions=shared_types.MutationCounts(),
                removals=shared_types.MutationCounts(),
            )
            _record_organization_apply_event(command_event, result)
            raise SystemExit(1)

        membership_plan = _plan_organization_sync(
            sync_state.targets,
            sync_state.current_states,
            sync_state.current_user,
        )
        add_counts = _apply_user_changes(
            client,
            membership_plan["additions"],
            sync_state.current_states,
            "add",
            parallelism,
            worker_pool,
        )
        remove_counts = _apply_user_changes(
            client,
            membership_plan["removals"],
            sync_state.current_states,
            "remove",
            parallelism,
            worker_pool,
        )
        return _OrganizationApplyResult(
            creates=create_counts,
            additions=add_counts,
            removals=remove_counts,
        )
260
+
261
+
262
def _finish_organization_apply_with_backup(
    client: src.SourcegraphClient,
    timestamp: str,
    sync_state: _OrganizationSyncState,
    before_path: Path,
    parallelism: int,
    worker_pool: ThreadPoolExecutor | None = None,
) -> None:
    """Re-read live org state after apply, write after/diff snapshots, then validate.

    Validation is delegated to ``_validate_organization_sync``, which compares
    the after snapshot with the expected snapshot. NOTE(review): that helper is
    presumably what raises on a mismatch — confirm.
    """
    organization_names = sorted(sync_state.targets)
    user_after, states_after = _load_current_organization_states(
        client, organization_names, parallelism, worker_pool
    )
    user_before = sync_state.current_user
    if user_after["id"] != user_before["id"]:
        log.warning(
            "Current user changed during org sync (%s → %s); validation still uses org members.",
            user_before["username"],
            user_after["username"],
        )
    endpoint = client.endpoint
    after_snapshot = _snapshot_from_states(endpoint, sync_state.targets, states_after)
    after_path, diff_path = _write_organization_after_snapshot(
        timestamp, endpoint, sync_state.before_snapshot, after_snapshot
    )
    log.info("Wrote after org snapshot: %s diff=%s.", after_path, diff_path)
    _validate_organization_sync(after_snapshot, sync_state.expected_snapshot)
    log.info("To inspect the pre-sync org membership state, read:\n %s", before_path)
292
+
293
+
294
def _write_organization_after_snapshot(
    timestamp: str,
    endpoint: str,
    before_snapshot: organization_types.OrganizationSnapshot,
    after_snapshot: organization_types.OrganizationSnapshot,
) -> tuple[Path, Path]:
    """Write the post-apply snapshot and its diff against the before snapshot.

    Returns:
        ``(after_path, diff_path)``.
    """
    command_name = "sync-saml-orgs-apply"
    after_path, diff_path = (
        _organization_snapshot_path(timestamp, endpoint, command_name, kind)
        for kind in ("after", "diff")
    )
    _write_organization_snapshot(after_path, after_snapshot)
    _write_organization_snapshot_diff(diff_path, before_snapshot, after_snapshot)
    return after_path, diff_path
305
+
306
+
307
def _raise_for_failed_organization_sync(result: _OrganizationApplyResult) -> None:
    """Exit with status 1 if any membership add/remove failed or was canceled.

    Create-phase counts are not inspected here.

    Raises:
        SystemExit: ``SystemExit(1)`` after logging a summary error.
    """
    membership_phases = (result.additions, result.removals)
    failed = sum(counts.failed for counts in membership_phases)
    canceled = sum(counts.canceled for counts in membership_phases)
    if failed or canceled:
        log.error(
            "SAML org sync failed: %d mutation(s) failed, %d canceled. "
            "Review the log file and org snapshots, then re-run.",
            failed,
            canceled,
        )
        raise SystemExit(1)
319
+
320
+
321
def cmd_sync_saml_organizations(
    client: src.SourcegraphClient,
    *,
    dry_run: bool,
    parallelism: int,
    saml_groups_attribute_name_by_config_id: dict[str, str],
    do_backup: bool,
    command_data: run_context.CommandData | None = None,
    worker_pool: ThreadPoolExecutor | None = None,
) -> None:
    """Create/update Sourcegraph orgs from every discovered SAML group.

    Org names are deterministic and config-free: each is the Sourcegraph-safe
    form of `<auth provider configID>-<group name>`. Characters that are
    invalid in an org name become `-`. If two SAML groups normalize to the same
    name, the run fails before any mutation, so unrelated groups are never
    merged by accident.

    Flow:
      1. Load the state and log the plan.
      2. In a dry run, optionally write snapshots and stop.
      3. Otherwise, optionally write the before snapshot and apply the changes.
      4. Record the counts; when backups are on, write after/diff snapshots
         and validate.
      5. Exit non-zero if any membership mutation failed.
    """
    with src.event(
        "cmd_sync_saml_organizations",
        dry_run=dry_run,
        parallelism=parallelism,
        do_backup=do_backup,
    ) as command_event:
        effective_command_data = command_data or run_context.CommandData()
        sync_state = _load_organization_sync_state(
            client,
            saml_groups_attribute_name_by_config_id,
            parallelism,
            command_event,
            effective_command_data,
            worker_pool,
        )
        if sync_state is None:
            return

        _log_organization_sync_plan(sync_state)

        timestamp = backups.backup_timestamp()
        if dry_run:
            _finish_organization_dry_run(client.endpoint, timestamp, sync_state, do_backup)
            return

        before_path = (
            _write_organization_apply_before_snapshot(
                client.endpoint, timestamp, sync_state.before_snapshot
            )
            if do_backup
            else None
        )

        apply_result = _apply_organization_sync(
            client, sync_state, parallelism, command_event, worker_pool
        )
        _record_organization_apply_event(command_event, apply_result)

        if do_backup:
            # before_path is always set when do_backup is true; the assert
            # narrows the Optional for type checkers.
            assert before_path is not None
            _finish_organization_apply_with_backup(
                client, timestamp, sync_state, before_path, parallelism, worker_pool
            )

        _raise_for_failed_organization_sync(apply_result)
391
+
392
+
393
def _collect_target_organizations(
    providers: list[shared_types.AuthProvider],
    attribute_names_by_provider: saml_groups.SamlGroupsAttributeNameByProvider,
    client: src.SourcegraphClient,
    saml_group_users: Iterable[shared_types.SamlGroupUser] | None,
) -> dict[str, organization_types.TargetOrganization]:
    """Map every SAML group membership to a target org with its desired members.

    Args:
        providers: All auth providers. SAML accounts are matched through
            ``saml_groups.saml_providers_by_account_key``.
        attribute_names_by_provider: Per-provider SAML groups attribute names.
        client: Used only when ``saml_group_users`` is None, to stream users.
        saml_group_users: Precomputed compact users from earlier in the run.
            Any iterable is accepted; a count is logged only when it is sized.

    Returns:
        Target orgs keyed by generated org name.

    Raises:
        SystemExit: if two different SAML groups normalize to the same org name.
    """
    # Local import keeps this fix self-contained; Sized is stdlib.
    from collections.abc import Sized

    providers_by_account_key = saml_groups.saml_providers_by_account_key(providers)
    targets: dict[str, organization_types.TargetOrganization] = {}
    collisions: set[str] = set()
    started = time.perf_counter()
    progress_step = 1000

    users: Iterable[shared_types.SamlGroupUser | None]
    if saml_group_users is None:
        log.info("Streaming users once and extracting SAML group memberships ...")
        # compact_saml_group_user may return None (presumably for users with
        # no SAML group data). Those users are skipped below but still count
        # toward progress.
        users = (
            saml_groups.compact_saml_group_user(
                user,
                providers_by_account_key,
                attribute_names_by_provider,
            )
            for user in shared_sourcegraph.list_users_streaming(client)
        )
    elif isinstance(saml_group_users, Sized):
        log.info(
            "Reusing %d precomputed SAML group user(s) loaded earlier in this run ...",
            len(saml_group_users),
        )
        users = saml_group_users
    else:
        # Previously any non-list iterable was reported as "0" users. Only
        # report a count when one is available without consuming the iterable.
        log.info("Reusing precomputed SAML group users loaded earlier in this run ...")
        users = saml_group_users

    for completed, user in enumerate(users, start=1):
        if user is not None:
            _record_saml_group_user_memberships(targets, collisions, user)
        if completed % progress_step == 0:
            _log_target_collection_progress(completed, started, targets)

    _raise_for_target_collisions(collisions)
    _log_target_collection_summary(targets)
    return targets
428
+
429
+
430
def _record_saml_group_user_memberships(
    targets: dict[str, organization_types.TargetOrganization],
    collisions: set[str],
    user: shared_types.SamlGroupUser,
) -> None:
    """Record ``user`` as a desired member of the org for each of its SAML groups."""
    for saml_membership in user.saml_group_memberships:
        _record_target_organization_membership(
            targets,
            collisions,
            provider_config_id=saml_membership.provider_config_id,
            group_name=saml_membership.group_name,
            user=user,
        )
443
+
444
+
445
def _record_target_organization_membership(
    targets: dict[str, organization_types.TargetOrganization],
    collisions: set[str],
    provider_config_id: str,
    group_name: str,
    user: shared_types.SamlGroupUser,
) -> None:
    """Add ``user`` to the target org generated for (provider, SAML group).

    If the generated org name already belongs to a different (provider, group)
    pair, a collision description is added to ``collisions``. In that case the
    user is not recorded.
    """
    organization_name = organization_name_for_saml_group(provider_config_id, group_name)
    target = targets.get(organization_name)
    if target is None:
        target = organization_types.TargetOrganization(
            name=organization_name,
            provider_config_id=provider_config_id,
            saml_group=group_name,
        )
    elif (target.provider_config_id, target.saml_group) != (provider_config_id, group_name):
        collisions.add(
            f"{organization_name!r} maps both "
            f"{target.provider_config_id!r}/{target.saml_group!r} "
            f"and {provider_config_id!r}/{group_name!r}"
        )
        return

    target.desired_members_by_id[user.user_id] = {
        "id": user.user_id,
        "username": user.username,
    }
    targets[organization_name] = target
475
+
476
+
477
def _log_target_collection_progress(
    completed: int,
    started: float,
    targets: dict[str, organization_types.TargetOrganization],
) -> None:
    """Log users processed so far, elapsed time, throughput, and target-org count.

    ``started`` is a ``time.perf_counter()`` reading taken when collection began.
    """
    elapsed_seconds = time.perf_counter() - started
    users_per_second = 0.0 if elapsed_seconds <= 0 else completed / elapsed_seconds
    log.info(
        "Processed %d user record(s) for SAML groups in %.0fs "
        "(%.0f users/sec); found %d target org(s).",
        completed,
        elapsed_seconds,
        users_per_second,
        len(targets),
    )
492
+
493
+
494
def _raise_for_target_collisions(collisions: set[str]) -> None:
    """Abort the run if any generated org names collided.

    Raises:
        SystemExit: The message lists every collision, sorted, one per line.
    """
    if collisions:
        details = "".join(f"\n - {collision}" for collision in sorted(collisions))
        raise SystemExit(
            "FATAL: SAML group org-name collision(s) after Sourcegraph-safe normalization:"
            + details
        )
503
+
504
+
505
def _log_target_collection_summary(
    targets: dict[str, organization_types.TargetOrganization],
) -> None:
    """Log how many target orgs were found and the total desired memberships."""
    membership_total = sum(
        map(len, (target.desired_members_by_id for target in targets.values()))
    )
    log.info(
        "Found %d SAML-backed target org(s) covering %d desired membership(s).",
        len(targets),
        membership_total,
    )
513
+
514
+
515
def organization_name_for_saml_group(provider_config_id: str, group_name: str) -> str:
    """Return the deterministic Sourcegraph org name for a SAML group.

    Both parts are normalized by ``_organization_name_part`` and joined with
    ``-``. For example, configID ``"okta"`` and group ``"Eng Team"`` give
    ``"okta-Eng-Team"``.

    Raises:
        SystemExit: if either part normalizes to empty, or if the joined name
            is longer than ``ORGANIZATION_NAME_MAX_LENGTH``.
    """
    organization_name = "-".join(
        (
            _organization_name_part(provider_config_id, "auth provider configID"),
            _organization_name_part(group_name, "SAML group name"),
        )
    )
    if len(organization_name) <= ORGANIZATION_NAME_MAX_LENGTH:
        return organization_name
    raise SystemExit(
        f"FATAL: generated org name for configID={provider_config_id!r} "
        f"group={group_name!r} is {len(organization_name)} characters; "
        f"Sourcegraph org names must be <= {ORGANIZATION_NAME_MAX_LENGTH}."
    )
526
+
527
+
528
def _organization_name_part(value: str, label: str) -> str:
    """Normalize ``value`` into one dash-separated, alphanumeric org-name part.

    Runs of characters outside ``[A-Za-z0-9]`` become a single ``-``, and
    leading and trailing dashes are dropped.

    Raises:
        SystemExit: if nothing is left after normalization. ``label`` names
            the source field in the message.
    """
    dashed = _ORGANIZATION_NAME_PART_RE.sub("-", value.strip())
    collapsed = _ORGANIZATION_NAME_DASH_RUN_RE.sub("-", dashed)
    part = collapsed.strip("-")
    if part:
        return part
    raise SystemExit(
        f"FATAL: {label} {value!r} cannot be converted to a Sourcegraph org-name part."
    )
536
+
537
+
538
def _load_current_organization_states(
    client: src.SourcegraphClient,
    organization_names: list[str],
    parallelism: int,
    worker_pool: ThreadPoolExecutor | None = None,
) -> tuple[organization_types.OrgMember, dict[str, organization_types.OrganizationState]]:
    """Look up each org by name and load the members of those that exist.

    The work runs in two parallel phases on the same pool:
      1. Batched name lookups via ``_fetch_organization_batch``. Each batch also
         returns ``currentUser``, and all batches must agree on it.
      2. Member paging via ``_fetch_all_members`` for every org that has an id.

    Returns:
        ``(current_user, states)``, with states keyed by org name.

    Raises:
        RuntimeError: if ``currentUser`` differs between batches, or if it
            was never returned (for example, when ``organization_names`` is
            empty).
    """
    states: dict[str, organization_types.OrganizationState] = {}
    current_user: organization_types.OrgMember | None = None
    name_batches = list(_chunks(organization_names, ORGANIZATION_LOOKUP_BATCH_SIZE))
    with src.event(
        "load_current_organization_states",
        organization_count=len(organization_names),
        lookup_batch_count=len(name_batches),
        member_page_size=ORGANIZATION_MEMBER_PAGE_SIZE,
    ) as load_event:
        with run_context.thread_pool(parallelism, worker_pool) as executor:
            # Phase 1: resolve names (and the caller) batch by batch.
            lookup_futures = [
                src.submit_with_log_context(executor, _fetch_organization_batch, client, batch)
                for batch in name_batches
            ]
            for finished in as_completed(lookup_futures):
                lookup = finished.result()
                batch_user = lookup["current_user"]
                if current_user is None:
                    current_user = batch_user
                elif current_user["id"] != batch_user["id"]:
                    raise RuntimeError(
                        "currentUser changed between organization lookup batches "
                        f"({current_user['username']} vs {batch_user['username']})"
                    )
                states.update(lookup["states"])

            # Phase 2: page members only for orgs that already exist.
            states_with_ids = [state for state in states.values() if state.id is not None]
            load_event["existing_organizations_needing_member_pages"] = len(states_with_ids)
            member_futures = {
                src.submit_with_log_context(executor, _fetch_all_members, client, state): state
                for state in states_with_ids
            }
            for finished in as_completed(member_futures):
                owner = member_futures[finished]
                owner.members_by_id.update(
                    (member["id"], member) for member in finished.result()
                )

        load_event["existing_organizations"] = sum(1 for state in states.values() if state.id)
        load_event["total_current_members"] = sum(
            len(state.members_by_id) for state in states.values()
        )

    if current_user is None:
        raise RuntimeError("currentUser was not returned while loading organizations")
    return current_user, states
596
+
597
+
598
def _fetch_organization_batch(
    client: src.SourcegraphClient,
    organization_names: list[str],
) -> organization_types.OrganizationBatchLookup:
    """Resolve a batch of org names with one aliased GraphQL query.

    Each name is searched with ``organizations(query: ...)``. Only an exact
    ``name`` match among the first ``ORGANIZATION_SEARCH_RESULT_LIMIT`` results
    counts as existing. Members are not loaded here: every returned state has
    an empty ``members_by_id``, and callers page members separately.

    Returns:
        ``{"current_user": ..., "states": {name: OrganizationState}}``. A
        missing org has ``id=None``.

    Raises:
        RuntimeError: if ``currentUser`` is null (unauthenticated token), or if
            a lookup alias is missing from the response.
    """
    query = _organization_batch_query(len(organization_names))
    # Send only the variables the query declares ($organizationFirst and
    # $name<i>). The lookup selects no members, so the previously sent, unused
    # `memberFirst` variable is omitted.
    variables: dict[str, Any] = {"organizationFirst": ORGANIZATION_SEARCH_RESULT_LIMIT}
    variables.update(
        {f"name{index}": organization_name for index, organization_name in enumerate(organization_names)}
    )
    with src.event(
        "organization_batch_lookup",
        level="DEBUG",
        organization_count=len(organization_names),
    ) as lookup_event:
        data = client.graphql(query, variables)
        current_user = cast(organization_types.OrgMember | None, data.get("currentUser"))
        if current_user is None:
            raise RuntimeError("currentUser is null; SAML org sync requires an authenticated token")
        states: dict[str, organization_types.OrganizationState] = {}
        existing_count = 0
        for index, organization_name in enumerate(organization_names):
            raw_connection = cast(dict[str, Any] | None, data.get(f"organization{index}"))
            if raw_connection is None:
                raise RuntimeError(
                    f"organizations lookup alias organization{index} was missing from response"
                )
            raw_organizations = cast(list[dict[str, Any]], raw_connection["nodes"])
            # The search is fuzzy, so only an exact name match counts.
            raw_organization = next(
                (
                    organization
                    for organization in raw_organizations
                    if organization.get("name") == organization_name
                ),
                None,
            )
            if raw_organization is None:
                total_count = cast(int, raw_connection["totalCount"])
                if total_count >= ORGANIZATION_SEARCH_RESULT_LIMIT:
                    # The exact match may be beyond the first page of results.
                    log.warning(
                        "Org lookup for %s returned %d search result(s) without an exact "
                        "match in the first %d; treating it as missing.",
                        organization_name,
                        total_count,
                        ORGANIZATION_SEARCH_RESULT_LIMIT,
                    )
                states[organization_name] = organization_types.OrganizationState(
                    id=None,
                    name=organization_name,
                    members_by_id={},
                )
                continue
            existing_count += 1
            states[organization_name] = organization_types.OrganizationState(
                id=cast(str, raw_organization["id"]),
                name=cast(str, raw_organization["name"]),
                members_by_id={},
            )
        lookup_event["existing_organizations"] = existing_count
        return {"current_user": current_user, "states": states}
659
+
660
+
661
def _organization_batch_query(organization_count: int) -> str:
    """Build the aliased GraphQL lookup query for ``organization_count`` names.

    The query declares ``$organizationFirst`` and ``$name0 .. $name<n-1>``.
    It selects ``currentUser { id username }`` plus one
    ``organization<i>: organizations(...)`` alias per name, each returning
    ``totalCount`` and the ``id``/``name`` of the matching nodes.
    """
    indices = range(organization_count)
    declarations = ", ".join(
        ["$organizationFirst: Int!", *(f"$name{i}: String!" for i in indices)]
    )
    selections = ["currentUser { id username }"]
    selections.extend(
        (
            f"\norganization{i}: organizations(first: $organizationFirst, query: $name{i}) {{\n"
            "totalCount\n"
            "nodes {\n"
            "id\n"
            "name\n"
            "}\n"
            "}"
        )
        for i in indices
    )
    body = "\n".join(selections)
    return f"query SamlOrganizationLookup({declarations}) {{\n{body}\n}}\n"
682
+
683
+
684
def _fetch_all_members(
    client: src.SourcegraphClient,
    state: organization_types.OrganizationState,
) -> list[organization_types.OrgMember]:
    """Return every member of an existing org by paging ``node.members``.

    An org that has not been created yet (``state.id is None``) has no
    members, so the result is ``[]`` and no request is made.
    """
    if state.id is None:
        return []
    with src.event("organization_members", level="DEBUG", organization_name=state.name):
        member_nodes = client.stream_connection_nodes(
            queries.QUERY_ORGANIZATION_MEMBERS_PAGE,
            {"id": state.id},
            connection_path=("node", "members"),
            page_size=ORGANIZATION_MEMBER_PAGE_SIZE,
        )
        return cast(list[organization_types.OrgMember], list(member_nodes))
700
+
701
+
702
def _plan_organization_sync(
    targets: dict[str, organization_types.TargetOrganization],
    current_states: dict[str, organization_types.OrganizationState],
    current_user: organization_types.OrgMember,
) -> organization_types.OrganizationPlan:
    """Diff desired SAML membership against the live org state.

    Returns a plan dict containing:
      * ``create_names``: target orgs that do not exist yet (``id is None``);
      * ``additions``: desired members missing from the org;
      * ``removals``: current members that are not in the SAML group.

    Orgs are visited in name order, and the changes within an org are ordered
    by username.
    """

    def by_username(item: tuple[str, organization_types.OrgMember]) -> str:
        return item[1]["username"]

    create_names: list[str] = []
    additions: list[organization_types.OrganizationUserChange] = []
    removals: list[organization_types.OrganizationUserChange] = []
    for organization_name in sorted(targets):
        desired_members = targets[organization_name].desired_members_by_id
        state = current_states[organization_name]
        effective_members = dict(state.members_by_id)
        if state.id is None:
            create_names.append(organization_name)
            # createOrganization automatically adds the caller. Count them as
            # a member up front, so we neither add the caller again nor leave
            # them in an org whose SAML group they are not part of.
            effective_members[current_user["id"]] = current_user
        additions.extend(
            organization_types.OrganizationUserChange(
                organization_name=organization_name,
                user_id=user_id,
                username=member["username"],
            )
            for user_id, member in sorted(desired_members.items(), key=by_username)
            if user_id not in effective_members
        )
        removals.extend(
            organization_types.OrganizationUserChange(
                organization_name=organization_name,
                user_id=user_id,
                username=member["username"],
            )
            for user_id, member in sorted(effective_members.items(), key=by_username)
            if user_id not in desired_members
        )
    return {"create_names": create_names, "additions": additions, "removals": removals}
742
+
743
+
744
def _expected_states_from_targets(
    targets: dict[str, organization_types.TargetOrganization],
    current_states: dict[str, organization_types.OrganizationState],
) -> dict[str, organization_types.OrganizationState]:
    """Describe the state a fully successful sync should leave behind.

    Each target keeps its current org id (``None`` if not yet created), and
    its members are exactly the desired SAML members. The member dict is
    copied so later mutation of a target does not leak into the result.
    """
    expected: dict[str, organization_types.OrganizationState] = {}
    for name, target in targets.items():
        expected[name] = organization_types.OrganizationState(
            id=current_states[name].id,
            name=name,
            members_by_id=dict(target.desired_members_by_id),
        )
    return expected
756
+
757
+
758
def _apply_create_organizations(
    client: src.SourcegraphClient,
    organization_names: list[str],
    current_states: dict[str, organization_types.OrganizationState],
    current_user: organization_types.OrgMember,
    parallelism: int,
    worker_pool: ThreadPoolExecutor | None = None,
) -> shared_types.MutationCounts:
    """Create the given orgs in parallel, stopping early via a circuit breaker.

    Each successful create replaces ``current_states[organization_name]`` in
    place with the state returned by ``_create_organization``. Callers then
    re-plan against orgs that now have ids.

    Args:
        client: Sourcegraph API client.
        organization_names: Orgs to create. An empty list short-circuits.
        current_states: Live org state, mutated in place on success.
        current_user: Passed through to ``_create_organization``.
        parallelism: Worker count for ``run_context.thread_pool``.
        worker_pool: Optional shared executor.

    Returns:
        Succeeded/failed/canceled counts for the create phase.
    """
    if not organization_names:
        return shared_types.MutationCounts()
    with src.event(
        "apply_create_organizations",
        organization_count=len(organization_names),
        parallelism=parallelism,
    ) as batch_event:
        breaker = permissions_apply.CircuitBreaker()
        succeeded = 0
        failed = 0
        canceled = 0
        with run_context.thread_pool(parallelism, worker_pool) as executor:
            futures = {
                src.submit_with_log_context(
                    executor,
                    _create_organization,
                    client,
                    organization_name,
                    current_user,
                ): organization_name
                for organization_name in organization_names
            }
            for future in as_completed(futures):
                organization_name = futures[future]
                try:
                    state = future.result()
                    current_states[organization_name] = state
                    succeeded += 1
                    breaker.record(success=True)
                    log.info(" OK create org %s.", organization_name)
                except CancelledError:
                    # Futures canceled by the breaker below surface here. They
                    # are counted as canceled and not fed back into the breaker.
                    canceled += 1
                    continue
                except Exception as exception:
                    failed += 1
                    breaker.record(success=False)
                    log.error(" FAIL create org %s: %s", organization_name, exception)
                if breaker.is_open():
                    # Stop scheduling more creates. Only futures that are not
                    # done yet are canceled; per concurrent.futures semantics,
                    # a future that is already running cannot be canceled.
                    for pending_future in futures:
                        if not pending_future.done():
                            pending_future.cancel()
        batch_event["succeeded"] = succeeded
        batch_event["failed"] = failed
        batch_event["canceled"] = canceled
        batch_event["circuit_broken"] = breaker.is_open()
        return shared_types.MutationCounts(
            succeeded=succeeded,
            failed=failed,
            canceled=canceled,
        )
816
+
817
+
818
def _create_organization(
    client: src.SourcegraphClient,
    organization_name: str,
    current_user: organization_types.OrgMember,
) -> organization_types.OrganizationState:
    """Create a single organization and return its resulting state.

    On success the returned state lists ``current_user`` as the only member.
    If the mutation fails with an error whose text contains
    ``_ORGANIZATION_EXISTS_TEXT``, the existing organization (and its members)
    is re-read instead and returned.

    Raises:
        src.GraphQLError: For any mutation error not matching the "exists" text.
    """
    with src.event("create_organization", organization_name=organization_name):
        variables = {"name": organization_name, "displayName": None}
        try:
            data = client.graphql(queries.MUTATION_CREATE_ORGANIZATION, variables)
        except src.GraphQLError as exception:
            already_exists = _ORGANIZATION_EXISTS_TEXT in str(exception)
            if not already_exists:
                raise
            # Another run (or a concurrent worker) created it first: adopt it.
            log.warning(
                "createOrganization reported %s already exists; re-reading it and continuing.",
                organization_name,
            )
            return _fetch_single_organization_state(client, organization_name)
        created = cast(organization_types.CreatedOrganization, data["createOrganization"])
        initial_members = {current_user["id"]: current_user}
        return organization_types.OrganizationState(
            id=created["id"],
            name=created["name"],
            members_by_id=initial_members,
        )
843
+
844
+
845
def _fetch_single_organization_state(
    client: src.SourcegraphClient,
    organization_name: str,
) -> organization_types.OrganizationState:
    """Re-read one organization by name, including its full member list.

    Used after ``createOrganization`` reports a conflict.

    Raises:
        RuntimeError: If the lookup still reports no ID for the organization.
    """
    batch = _fetch_organization_batch(client, [organization_name])
    state = batch["states"][organization_name]
    if state.id is None:
        raise RuntimeError(
            f"organization {organization_name!r} still does not exist after "
            "createOrganization conflict"
        )
    state.members_by_id.update(
        (member["id"], member) for member in _fetch_all_members(client, state)
    )
    return state
859
+
860
+
861
def _apply_user_changes(
    client: src.SourcegraphClient,
    changes: list[organization_types.OrganizationUserChange],
    current_states: dict[str, organization_types.OrganizationState],
    change_kind: organization_types.OrganizationChangeKind,
    parallelism: int,
    worker_pool: ThreadPoolExecutor | None = None,
) -> shared_types.MutationCounts:
    """Apply one kind of membership change concurrently behind a circuit breaker.

    Each change is executed by ``_apply_user_change`` against the state of its
    organization in ``current_states`` (looked up at submission time). Each
    success/failure is recorded on a circuit breaker; once it is open, every
    submitted future that has not finished yet is cancelled.

    Returns:
        The number of changes that succeeded, failed, and were cancelled.
    """
    if not changes:
        return shared_types.MutationCounts()
    # Only used for log wording: "add X to org", otherwise "<kind> X from org".
    preposition = "to" if change_kind == "add" else "from"
    with src.event(
        "apply_organization_user_changes",
        change_kind=change_kind,
        change_count=len(changes),
        parallelism=parallelism,
    ) as batch_event:
        breaker = permissions_apply.CircuitBreaker()
        succeeded = failed = canceled = 0
        with run_context.thread_pool(parallelism, worker_pool) as executor:
            change_by_future = {}
            for pending_change in changes:
                handle = src.submit_with_log_context(
                    executor,
                    _apply_user_change,
                    client,
                    pending_change,
                    current_states[pending_change.organization_name],
                    change_kind,
                )
                change_by_future[handle] = pending_change
            for handle in as_completed(change_by_future):
                done_change = change_by_future[handle]
                try:
                    handle.result()
                    succeeded += 1
                    breaker.record(success=True)
                    log.info(
                        " OK %s %s %s org %s.",
                        change_kind,
                        done_change.username,
                        preposition,
                        done_change.organization_name,
                    )
                except CancelledError:
                    # Cancelled by the breaker below; nothing to record.
                    canceled += 1
                    continue
                except Exception as exception:
                    failed += 1
                    breaker.record(success=False)
                    log.error(
                        " FAIL %s %s %s org %s: %s",
                        change_kind,
                        done_change.username,
                        preposition,
                        done_change.organization_name,
                        exception,
                    )
                if breaker.is_open():
                    for outstanding in change_by_future:
                        if not outstanding.done():
                            outstanding.cancel()
        batch_event["succeeded"] = succeeded
        batch_event["failed"] = failed
        batch_event["canceled"] = canceled
        batch_event["circuit_broken"] = breaker.is_open()
        return shared_types.MutationCounts(
            succeeded=succeeded,
            failed=failed,
            canceled=canceled,
        )
933
+
934
+
935
def _apply_user_change(
    client: src.SourcegraphClient,
    change: organization_types.OrganizationUserChange,
    state: organization_types.OrganizationState,
    change_kind: organization_types.OrganizationChangeKind,
) -> None:
    """Add a user to, or remove a user from, one organization.

    ``change_kind == "add"`` adds by username. An "already a member" error
    (matched on ``_ALREADY_MEMBER_TEXT``) is logged and treated as success.
    Any other kind removes the user by ``change.user_id``.

    Raises:
        RuntimeError: If the organization has no ID (it does not exist).
        src.GraphQLError: If the mutation fails, except for the tolerated
            "already a member" case on add.
    """
    if state.id is None:
        raise RuntimeError(f"organization {change.organization_name!r} has no ID")
    if change_kind != "add":
        with src.event(
            "remove_user_from_organization",
            organization_name=change.organization_name,
            username=change.username,
        ):
            client.graphql(
                queries.MUTATION_REMOVE_USER_FROM_ORGANIZATION,
                {"organization": state.id, "user": change.user_id},
            )
        return
    with src.event(
        "add_user_to_organization",
        organization_name=change.organization_name,
        username=change.username,
    ):
        try:
            client.graphql(
                queries.MUTATION_ADD_USER_TO_ORGANIZATION,
                {"organization": state.id, "username": change.username},
            )
        except src.GraphQLError as exception:
            if _ALREADY_MEMBER_TEXT not in str(exception):
                raise
            # The desired end state already holds, so this counts as success.
            log.info(
                " Already a member: %s in org %s; treating as success.",
                change.username,
                change.organization_name,
            )
973
+
974
+
975
def _snapshot_from_states(
    endpoint: str,
    targets: dict[str, organization_types.TargetOrganization],
    states: dict[str, organization_types.OrganizationState],
) -> organization_types.OrganizationSnapshot:
    """Build a JSON-serialisable snapshot of the target organizations.

    Entries are keyed by organization name and inserted in sorted name order.
    Each entry records the ID from ``states``, the SAML provider config and
    group from ``targets``, and both current (``states``) and desired
    (``targets``) members sorted by username. ``captured_at`` is the current
    UTC time in ISO format at second precision.

    Raises:
        KeyError: If a target organization has no entry in ``states``.
    """
    organizations: dict[str, organization_types.OrganizationSnapshotEntry] = {}
    for org_name in sorted(targets):
        target = targets[org_name]
        org_state = states[org_name]
        organizations[org_name] = {
            "id": org_state.id,
            "provider_config_id": target.provider_config_id,
            "saml_group": target.saml_group,
            "members": _sorted_members(org_state.members_by_id),
            "desired_members": _sorted_members(target.desired_members_by_id),
        }
    entries = list(organizations.values())
    captured_at = datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds")
    return {
        "schema_version": ORGANIZATION_SNAPSHOT_SCHEMA_VERSION,
        "captured_at": captured_at,
        "endpoint": endpoint,
        "stats": {
            "target_organizations": len(entries),
            "existing_organizations": len(
                [entry for entry in entries if entry["id"] is not None]
            ),
            "total_current_members": sum(len(entry["members"]) for entry in entries),
            "total_desired_members": sum(
                len(entry["desired_members"]) for entry in entries
            ),
        },
        "organizations": organizations,
    }
1008
+
1009
+
1010
+ def _sorted_members(
1011
+ members_by_id: dict[str, organization_types.OrgMember],
1012
+ ) -> list[organization_types.OrgMember]:
1013
+ return sorted(
1014
+ ({"id": member["id"], "username": member["username"]} for member in members_by_id.values()),
1015
+ key=lambda member: member["username"],
1016
+ )
1017
+
1018
+
1019
def _render_organization_diff(
    before: organization_types.OrganizationSnapshot, after: organization_types.OrganizationSnapshot
) -> str:
    """Render a human-readable per-organization membership diff.

    Only organizations whose member usernames differ between the two
    snapshots are listed. Each one gets a header and sorted added/removed
    username lists, plus a note when ``before`` lists it without an ID. A
    summary line with totals follows. Returns exactly ``"No changes."`` when
    no organization differs.
    """
    before_orgs = before["organizations"]
    after_orgs = after["organizations"]
    sections: list[str] = []
    added_total = removed_total = 0
    for name in sorted({*before_orgs, *after_orgs}):
        old = before_orgs.get(name)
        new = after_orgs.get(name)
        old_usernames = _snapshot_usernames(old["members"] if old else [])
        new_usernames = _snapshot_usernames(new["members"] if new else [])
        gained = sorted(new_usernames - old_usernames)
        lost = sorted(old_usernames - new_usernames)
        if not (gained or lost):
            continue
        sections.append(f"=== {name} ===")
        if old and old["id"] is None:
            sections.append(" * organization does not exist yet")
        if gained:
            sections.append(f" + added ({len(gained)}): {', '.join(gained)}")
        if lost:
            sections.append(f" - removed ({len(lost)}): {', '.join(lost)}")
        added_total += len(gained)
        removed_total += len(lost)
    if not sections:
        return "No changes."
    sections.extend(
        [
            "",
            f"Summary: {added_total} member(s) added, {removed_total} member(s) removed.",
        ]
    )
    return "\n".join(sections)
1050
+
1051
+
1052
def _build_organization_snapshot_diff(
    before: organization_types.OrganizationSnapshot,
    after: organization_types.OrganizationSnapshot,
) -> organization_types.OrganizationSnapshotDiff:
    """Build a machine-readable diff between two organization snapshots.

    An organization is included when its member usernames changed, or when it
    is flagged ``created``. ``created`` is true when ``before`` has no entry
    for it or lists it with a ``None`` ID. The ``after`` ID is not consulted.
    NOTE(review): a creation that failed (still no ID in ``after``) would also
    be reported as created — confirm this is intended.

    ID, provider config and SAML group are taken from the ``after`` entry when
    present, otherwise from ``before``.
    """
    before_orgs = before["organizations"]
    after_orgs = after["organizations"]
    changed: list[organization_types.OrganizationSnapshotDiffEntry] = []
    added_total = removed_total = 0
    for name in sorted({*before_orgs, *after_orgs}):
        old = before_orgs.get(name)
        new = after_orgs.get(name)
        old_usernames = _snapshot_usernames(old["members"] if old else [])
        new_usernames = _snapshot_usernames(new["members"] if new else [])
        gained = sorted(new_usernames - old_usernames)
        lost = sorted(old_usernames - new_usernames)
        newly_created = old is None or old["id"] is None
        if not (gained or lost or newly_created):
            continue
        reference = new or old
        if reference is None:
            # Unreachable (name came from one of the two dicts); narrows the type.
            continue
        added_total += len(gained)
        removed_total += len(lost)
        changed.append(
            {
                "name": name,
                "id": reference["id"],
                "provider_config_id": reference["provider_config_id"],
                "saml_group": reference["saml_group"],
                "created": newly_created,
                "before_count": len(old_usernames),
                "after_count": len(new_usernames),
                "added": gained,
                "removed": lost,
            }
        )
    created_count = len([entry for entry in changed if entry["created"]])
    return {
        "schema_version": ORGANIZATION_SNAPSHOT_DIFF_SCHEMA_VERSION,
        "diff_kind": "saml_organizations",
        "before_captured_at": before["captured_at"],
        "after_captured_at": after["captured_at"],
        "endpoint": after["endpoint"],
        "summary": {
            "organizations_changed": len(changed),
            "organizations_created": created_count,
            "members_added": added_total,
            "members_removed": removed_total,
        },
        "organizations": changed,
    }
1105
+
1106
+
1107
+ def _snapshot_usernames(members: list[organization_types.OrgMember]) -> set[str]:
1108
+ return {member["username"] for member in members}
1109
+
1110
+
1111
def _validate_organization_sync(
    after_snapshot: organization_types.OrganizationSnapshot,
    expected_snapshot: organization_types.OrganizationSnapshot,
) -> None:
    """Log whether post-sync memberships match the expected snapshot.

    The diff is rendered from ``after_snapshot`` to ``expected_snapshot``.
    "added" entries are therefore members still missing from an organization,
    and "removed" entries are members that should not be there. A mismatch is
    logged as a warning and does not raise.
    """
    rendered = _render_organization_diff(after_snapshot, expected_snapshot)
    if rendered != "No changes.":
        log.warning(
            "VALIDATION: target org memberships differ from desired SAML groups:\n%s", rendered
        )
        return
    log.info("VALIDATION OK: all target org memberships match discovered SAML groups.")
1120
+
1121
+
1122
def _write_organization_snapshot(
    path: Path, snapshot: organization_types.OrganizationSnapshot
) -> None:
    """Write ``snapshot`` to ``path`` as indented, key-sorted JSON with a trailing newline.

    Missing parent directories are created. The write runs inside a DEBUG-level
    ``disk_io`` event, which is given the length of the written text as ``bytes``.
    """
    event_fields = {
        "level": "DEBUG",
        "op": "write",
        "path": str(path),
        "file_kind": "organization_snapshot",
    }
    with src.event("disk_io", **event_fields) as disk_event:
        path.parent.mkdir(parents=True, exist_ok=True)
        serialized = json.dumps(snapshot, indent=2, sort_keys=True)
        serialized += "\n"
        path.write_text(serialized)
        disk_event["bytes"] = len(serialized)
1136
+
1137
+
1138
def _write_organization_snapshot_diff(
    path: Path,
    before: organization_types.OrganizationSnapshot,
    after: organization_types.OrganizationSnapshot,
) -> None:
    """Compute the before/after organization diff and write it to ``path`` as JSON.

    Missing parent directories are created first. The JSON is indented and
    key-sorted with a trailing newline. The write runs inside a DEBUG-level
    ``disk_io`` event, which is given the length of the written text as ``bytes``.
    """
    event_fields = {
        "level": "DEBUG",
        "op": "write",
        "path": str(path),
        "file_kind": "organization_snapshot_diff",
    }
    with src.event("disk_io", **event_fields) as disk_event:
        path.parent.mkdir(parents=True, exist_ok=True)
        snapshot_diff = _build_organization_snapshot_diff(before, after)
        serialized = json.dumps(snapshot_diff, indent=2, sort_keys=True)
        serialized += "\n"
        path.write_text(serialized)
        disk_event["bytes"] = len(serialized)
1155
+
1156
+
1157
def _organization_snapshot_path(
    timestamp: str,
    endpoint: str,
    command: str,
    state: str,
) -> Path:
    """Return the backup path for an organization snapshot file.

    Delegates to ``backups.backup_path`` using the fixed category
    ``"saml-organizations"``.
    """
    category = "saml-organizations"
    return backups.backup_path(category, timestamp, endpoint, command, state)
1164
+
1165
+
1166
+ def _chunks(values: list[str], size: int) -> list[list[str]]:
1167
+ return [values[start : start + size] for start in range(0, len(values), size)]