src-auth-perms-sync 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- src_auth_perms_sync/__init__.py +1 -0
- src_auth_perms_sync/__main__.py +6 -0
- src_auth_perms_sync/cli.py +646 -0
- src_auth_perms_sync/orgs/__init__.py +1 -0
- src_auth_perms_sync/orgs/command.py +7 -0
- src_auth_perms_sync/orgs/queries.py +44 -0
- src_auth_perms_sync/orgs/sync.py +1167 -0
- src_auth_perms_sync/orgs/types.py +103 -0
- src_auth_perms_sync/permissions/__init__.py +1 -0
- src_auth_perms_sync/permissions/apply.py +420 -0
- src_auth_perms_sync/permissions/command.py +918 -0
- src_auth_perms_sync/permissions/full_set.py +880 -0
- src_auth_perms_sync/permissions/mapping.py +627 -0
- src_auth_perms_sync/permissions/maps.py +291 -0
- src_auth_perms_sync/permissions/queries.py +180 -0
- src_auth_perms_sync/permissions/restore.py +913 -0
- src_auth_perms_sync/permissions/snapshot.py +1502 -0
- src_auth_perms_sync/permissions/sourcegraph.py +392 -0
- src_auth_perms_sync/permissions/types.py +116 -0
- src_auth_perms_sync/permissions/workflow.py +526 -0
- src_auth_perms_sync/shared/__init__.py +1 -0
- src_auth_perms_sync/shared/backups.py +119 -0
- src_auth_perms_sync/shared/id_codec.py +67 -0
- src_auth_perms_sync/shared/queries.py +65 -0
- src_auth_perms_sync/shared/run_context.py +34 -0
- src_auth_perms_sync/shared/saml_groups.py +267 -0
- src_auth_perms_sync/shared/site_config.py +366 -0
- src_auth_perms_sync/shared/sourcegraph.py +69 -0
- src_auth_perms_sync/shared/types.py +69 -0
- src_auth_perms_sync-0.2.1.dist-info/METADATA +256 -0
- src_auth_perms_sync-0.2.1.dist-info/RECORD +34 -0
- src_auth_perms_sync-0.2.1.dist-info/WHEEL +4 -0
- src_auth_perms_sync-0.2.1.dist-info/entry_points.txt +2 -0
- src_auth_perms_sync-0.2.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1167 @@
|
|
|
1
|
+
"""Sourcegraph organization sync command handler."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import datetime
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import re
|
|
9
|
+
import time
|
|
10
|
+
from collections.abc import Iterable
|
|
11
|
+
from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any, cast
|
|
15
|
+
|
|
16
|
+
import src_py_lib as src
|
|
17
|
+
|
|
18
|
+
from ..permissions import apply as permissions_apply
|
|
19
|
+
from ..shared import backups, run_context, saml_groups
|
|
20
|
+
from ..shared import sourcegraph as shared_sourcegraph
|
|
21
|
+
from ..shared import types as shared_types
|
|
22
|
+
from . import queries
|
|
23
|
+
from . import types as organization_types
|
|
24
|
+
|
|
25
|
+
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Number of organization names resolved per batched lookup query.
ORGANIZATION_LOOKUP_BATCH_SIZE: int = 50
# `first:` limit sent with every per-name `organizations(query: ...)` search.
ORGANIZATION_SEARCH_RESULT_LIMIT: int = 100
# Page size used when streaming the members of an existing organization.
ORGANIZATION_MEMBER_PAGE_SIZE: int = 1000
# Longest generated organization name accepted before the sync aborts.
ORGANIZATION_NAME_MAX_LENGTH: int = 255
# NOTE(review): presumably stamped into the written snapshot / diff files;
# the writers are not shown next to these constants - confirm.
ORGANIZATION_SNAPSHOT_SCHEMA_VERSION: int = 1
ORGANIZATION_SNAPSHOT_DIFF_SCHEMA_VERSION: int = 1

# Matches every run of characters outside [A-Za-z0-9]; such runs are replaced
# with "-" when an org-name part is normalized.
_ORGANIZATION_NAME_PART_RE = re.compile(r"[^A-Za-z0-9]+")
# Matches runs of dashes so they can be collapsed into a single "-".
_ORGANIZATION_NAME_DASH_RUN_RE = re.compile(r"-+")
# NOTE(review): these look like server error texts that the apply helpers
# treat as "already done" outcomes - confirm against their call sites.
_ALREADY_MEMBER_TEXT = "user is already a member of the organization"
_ORGANIZATION_EXISTS_TEXT = "organization name is already taken"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass(frozen=True)
class _OrganizationSyncState:
    """Immutable result of the discovery and planning phase of one org sync.

    Built by ``_load_organization_sync_state`` and then handed to the logging,
    dry-run, apply and validation steps.
    """

    # Desired organizations keyed by their generated organization name.
    targets: dict[str, organization_types.TargetOrganization]
    # ``currentUser`` as reported by the organization lookup queries.
    current_user: organization_types.OrgMember
    # Live organization state keyed by organization name.
    current_states: dict[str, organization_types.OrganizationState]
    # Snapshot built from ``current_states`` before anything is mutated.
    before_snapshot: organization_types.OrganizationSnapshot
    # Organizations to create plus membership additions and removals.
    plan: organization_types.OrganizationPlan
    # Snapshot built from the states the sync is expected to produce.
    expected_snapshot: organization_types.OrganizationSnapshot
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass(frozen=True)
class _OrganizationApplyResult:
    """Immutable per-phase mutation counters produced by one apply run."""

    # Outcome counts for the organization-creation phase.
    creates: shared_types.MutationCounts
    # Outcome counts for the member-addition phase.
    additions: shared_types.MutationCounts
    # Outcome counts for the member-removal phase.
    removals: shared_types.MutationCounts
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _load_organization_sync_state(
    client: src.SourcegraphClient,
    saml_groups_attribute_name_by_config_id: dict[str, str],
    parallelism: int,
    command_event: dict[str, Any],
    command_data: run_context.CommandData,
    worker_pool: ThreadPoolExecutor | None = None,
) -> _OrganizationSyncState | None:
    """Gather everything needed to plan one SAML organization sync.

    Auth providers come from ``command_data.auth_providers`` when it is set and
    are queried from ``client`` otherwise. Target organizations are derived
    from SAML group memberships, the live organization state is loaded, and
    both the mutation plan and the expected post-sync snapshot are computed.
    The target and desired-membership totals are written to ``command_event``.

    Returns:
        The assembled sync state, or ``None`` (after a warning) when there is
        no SAML auth provider or no target organization.
    """
    preloaded_providers = command_data.auth_providers
    if preloaded_providers is not None:
        providers = preloaded_providers
        log.info("Reusing %d auth provider(s) loaded earlier in this run.", len(providers))
    else:
        log.info("Querying auth providers from %s ...", client.endpoint)
        providers = shared_sourcegraph.list_auth_providers(client)

    saml_provider_count = sum(
        1 for provider in providers if provider["serviceType"] == saml_groups.SAML_SERVICE_TYPE
    )
    if saml_provider_count == 0:
        log.warning("No SAML auth providers found — nothing to sync.")
        return None
    attribute_names_by_provider = saml_groups.attribute_names_by_provider_key(
        providers, saml_groups_attribute_name_by_config_id
    )
    log.info("Received %d SAML auth provider(s).", saml_provider_count)

    targets = _collect_target_organizations(
        providers,
        attribute_names_by_provider,
        client,
        command_data.saml_group_users,
    )
    desired_membership_total = sum(
        len(target.desired_members_by_id) for target in targets.values()
    )
    command_event["target_organizations"] = len(targets)
    command_event["desired_memberships"] = desired_membership_total
    if not targets:
        log.warning("No SAML group memberships found in user accountData — nothing to sync.")
        return None

    current_user, current_states = _load_current_organization_states(
        client, sorted(targets), parallelism, worker_pool
    )
    endpoint = client.endpoint
    before_snapshot = _snapshot_from_states(endpoint, targets, current_states)
    plan = _plan_organization_sync(targets, current_states, current_user)
    expected_snapshot = _snapshot_from_states(
        endpoint,
        targets,
        _expected_states_from_targets(targets, current_states),
    )
    return _OrganizationSyncState(
        targets=targets,
        current_user=current_user,
        current_states=current_states,
        before_snapshot=before_snapshot,
        plan=plan,
        expected_snapshot=expected_snapshot,
    )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _log_organization_sync_plan(sync_state: _OrganizationSyncState) -> None:
    """Log the planned create/add/remove totals followed by the rendered diff."""
    plan = sync_state.plan
    planned_creates = len(plan["create_names"])
    planned_additions = len(plan["additions"])
    planned_removals = len(plan["removals"])
    log.info(
        "Organization sync plan: create %d org(s), add %d member(s), remove %d member(s).",
        planned_creates,
        planned_additions,
        planned_removals,
    )
    # Rendered only after the summary line has been emitted.
    rendered_diff = _render_organization_diff(
        sync_state.before_snapshot, sync_state.expected_snapshot
    )
    log.info("Diff (current → desired):\n%s", rendered_diff)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _write_organization_snapshot_pair(
    timestamp: str,
    endpoint: str,
    command_name: str,
    before_snapshot: organization_types.OrganizationSnapshot,
    after_snapshot: organization_types.OrganizationSnapshot,
) -> tuple[Path, Path, Path]:
    """Write the before snapshot, the after snapshot and their diff.

    Returns:
        ``(before_path, after_path, diff_path)`` for the three written files.
    """
    # All three paths are resolved first, then the files are written.
    before_path, after_path, diff_path = (
        _organization_snapshot_path(timestamp, endpoint, command_name, kind)
        for kind in ("before", "after", "diff")
    )
    for path, snapshot in ((before_path, before_snapshot), (after_path, after_snapshot)):
        _write_organization_snapshot(path, snapshot)
    _write_organization_snapshot_diff(diff_path, before_snapshot, after_snapshot)
    return before_path, after_path, diff_path
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _finish_organization_dry_run(
    endpoint: str,
    timestamp: str,
    sync_state: _OrganizationSyncState,
    do_backup: bool,
) -> None:
    """Finish a dry run: optionally write the snapshots, then log completion.

    With ``do_backup`` set, the current and expected snapshots plus their diff
    are written under the ``sync-saml-orgs-dry-run`` command name.
    """
    if not do_backup:
        log.info("Skipped dry-run org snapshots because --no-backup was set.")
    else:
        written_paths = _write_organization_snapshot_pair(
            timestamp,
            endpoint,
            "sync-saml-orgs-dry-run",
            sync_state.before_snapshot,
            sync_state.expected_snapshot,
        )
        log.info(
            "Wrote dry-run org snapshots: before=%s after=%s diff=%s.",
            *written_paths,
        )
    log.info("Dry run complete. Pass --apply to mutate organization membership.")
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _write_organization_apply_before_snapshot(
    endpoint: str,
    timestamp: str,
    before_snapshot: organization_types.OrganizationSnapshot,
) -> Path:
    """Write the pre-apply snapshot and return the path it was written to."""
    snapshot_path = _organization_snapshot_path(
        timestamp, endpoint, "sync-saml-orgs-apply", "before"
    )
    _write_organization_snapshot(snapshot_path, before_snapshot)
    log.info("Wrote before org snapshot: %s.", snapshot_path)
    return snapshot_path
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _record_organization_apply_event(
    command_event: dict[str, Any],
    result: _OrganizationApplyResult,
) -> None:
    """Copy the per-phase mutation counters onto ``command_event``.

    Writes ``<phase>_succeeded``, ``<phase>_failed`` and ``<phase>_canceled``
    for the ``create``, ``add`` and ``remove`` phases, in that order.
    """
    phases = (
        ("create", result.creates),
        ("add", result.additions),
        ("remove", result.removals),
    )
    for prefix, counts in phases:
        command_event[f"{prefix}_succeeded"] = counts.succeeded
        command_event[f"{prefix}_failed"] = counts.failed
        command_event[f"{prefix}_canceled"] = counts.canceled
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _apply_organization_sync(
    client: src.SourcegraphClient,
    sync_state: _OrganizationSyncState,
    parallelism: int,
    command_event: dict[str, Any],
    worker_pool: ThreadPoolExecutor | None = None,
) -> _OrganizationApplyResult:
    """Apply creates first, then recompute and apply member changes.

    Organizations listed in the plan are created first. If any creation fails
    or is canceled, the create counters are recorded on ``command_event``, an
    error is logged and the process exits before any membership is touched.
    Otherwise the membership plan is recomputed from ``sync_state`` and the
    additions are applied before the removals.

    Returns:
        The mutation counts of the create, add and remove phases.

    Raises:
        SystemExit: with status 1 when an organization creation failed or was
            canceled.
    """
    with src.stage("apply"):
        create_counts = _apply_create_organizations(
            client,
            sync_state.plan["create_names"],
            sync_state.current_states,
            sync_state.current_user,
            parallelism,
            worker_pool,
        )
        if create_counts.failed or create_counts.canceled:
            result = _OrganizationApplyResult(
                creates=create_counts,
                additions=shared_types.MutationCounts(),
                removals=shared_types.MutationCounts(),
            )
            # The caller never gets a result on this path, so the counters are
            # recorded here before exiting.
            _record_organization_apply_event(command_event, result)
            # Explain the non-zero exit instead of terminating silently.
            log.error(
                "SAML org sync failed: %d organization create(s) failed, %d canceled. "
                "Skipped membership changes. Review the log file, then re-run.",
                create_counts.failed,
                create_counts.canceled,
            )
            raise SystemExit(1)

        # Planned again after the creates so that the membership changes are
        # computed once the create phase has finished.
        membership_plan = _plan_organization_sync(
            sync_state.targets,
            sync_state.current_states,
            sync_state.current_user,
        )
        add_counts = _apply_user_changes(
            client,
            membership_plan["additions"],
            sync_state.current_states,
            "add",
            parallelism,
            worker_pool,
        )
        remove_counts = _apply_user_changes(
            client,
            membership_plan["removals"],
            sync_state.current_states,
            "remove",
            parallelism,
            worker_pool,
        )
        return _OrganizationApplyResult(
            creates=create_counts,
            additions=add_counts,
            removals=remove_counts,
        )
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _finish_organization_apply_with_backup(
    client: src.SourcegraphClient,
    timestamp: str,
    sync_state: _OrganizationSyncState,
    before_path: Path,
    parallelism: int,
    worker_pool: ThreadPoolExecutor | None = None,
) -> None:
    """Reload live state after an apply, persist it, and validate the outcome.

    The organizations in ``sync_state.targets`` are loaded again, the after
    snapshot and the before/after diff are written, and the after snapshot is
    checked against the expected one. A changed current user only produces a
    warning.
    """
    user_after, states_after = _load_current_organization_states(
        client, sorted(sync_state.targets), parallelism, worker_pool
    )
    user_before = sync_state.current_user
    if user_before["id"] != user_after["id"]:
        log.warning(
            "Current user changed during org sync (%s → %s); validation still uses org members.",
            user_before["username"],
            user_after["username"],
        )
    endpoint = client.endpoint
    after_snapshot = _snapshot_from_states(endpoint, sync_state.targets, states_after)
    after_path, diff_path = _write_organization_after_snapshot(
        timestamp, endpoint, sync_state.before_snapshot, after_snapshot
    )
    log.info("Wrote after org snapshot: %s diff=%s.", after_path, diff_path)
    _validate_organization_sync(after_snapshot, sync_state.expected_snapshot)
    log.info("To inspect the pre-sync org membership state, read:\n %s", before_path)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def _write_organization_after_snapshot(
    timestamp: str,
    endpoint: str,
    before_snapshot: organization_types.OrganizationSnapshot,
    after_snapshot: organization_types.OrganizationSnapshot,
) -> tuple[Path, Path]:
    """Write the post-apply snapshot and the before/after diff.

    Returns:
        ``(after_path, diff_path)`` for the two written files.
    """
    after_path, diff_path = (
        _organization_snapshot_path(timestamp, endpoint, "sync-saml-orgs-apply", kind)
        for kind in ("after", "diff")
    )
    _write_organization_snapshot(after_path, after_snapshot)
    _write_organization_snapshot_diff(diff_path, before_snapshot, after_snapshot)
    return after_path, diff_path
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _raise_for_failed_organization_sync(result: _OrganizationApplyResult) -> None:
    """Exit with status 1 when any membership mutation failed or was canceled.

    Only the addition and removal counters are inspected here.

    Raises:
        SystemExit: with status 1, after logging an error, when at least one
            addition or removal failed or was canceled.
    """
    additions, removals = result.additions, result.removals
    failed = additions.failed + removals.failed
    canceled = additions.canceled + removals.canceled
    if failed or canceled:
        log.error(
            "SAML org sync failed: %d mutation(s) failed, %d canceled. "
            "Review the log file and org snapshots, then re-run.",
            failed,
            canceled,
        )
        raise SystemExit(1)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def cmd_sync_saml_organizations(
    client: src.SourcegraphClient,
    *,
    dry_run: bool,
    parallelism: int,
    saml_groups_attribute_name_by_config_id: dict[str, str],
    do_backup: bool,
    command_data: run_context.CommandData | None = None,
    worker_pool: ThreadPoolExecutor | None = None,
) -> None:
    """Create/update Sourcegraph orgs from every discovered SAML group.

    Org names need no configuration and are deterministic: each one is the
    Sourcegraph-safe form of `<auth provider configID>-<group name>`, with
    invalid org-name characters converted to `-`. Any resulting name collision
    aborts the run before mutation, so unrelated SAML groups are never merged
    by accident.

    A dry run only logs the plan and optionally writes snapshots. Otherwise the
    plan is applied, the counters are recorded on the command event and, when
    backups are enabled, the result is snapshotted and validated.
    """
    with src.event(
        "cmd_sync_saml_organizations",
        dry_run=dry_run,
        parallelism=parallelism,
        do_backup=do_backup,
    ) as command_event:
        sync_state = _load_organization_sync_state(
            client,
            saml_groups_attribute_name_by_config_id,
            parallelism,
            command_event,
            command_data or run_context.CommandData(),
            worker_pool,
        )
        if sync_state is None:
            return

        _log_organization_sync_plan(sync_state)

        timestamp = backups.backup_timestamp()
        if dry_run:
            _finish_organization_dry_run(client.endpoint, timestamp, sync_state, do_backup)
            return

        # Stays None exactly when backups are disabled.
        before_path: Path | None = (
            _write_organization_apply_before_snapshot(
                client.endpoint, timestamp, sync_state.before_snapshot
            )
            if do_backup
            else None
        )

        apply_result = _apply_organization_sync(
            client, sync_state, parallelism, command_event, worker_pool
        )
        _record_organization_apply_event(command_event, apply_result)

        if before_path is not None:
            _finish_organization_apply_with_backup(
                client, timestamp, sync_state, before_path, parallelism, worker_pool
            )

        _raise_for_failed_organization_sync(apply_result)
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def _collect_target_organizations(
    providers: list[shared_types.AuthProvider],
    attribute_names_by_provider: saml_groups.SamlGroupsAttributeNameByProvider,
    client: src.SourcegraphClient,
    saml_group_users: Iterable[shared_types.SamlGroupUser] | None,
) -> dict[str, organization_types.TargetOrganization]:
    """Build the desired organizations from SAML group memberships.

    When ``saml_group_users`` is ``None`` every user is streamed from
    ``client`` and compacted first; otherwise the given users are consumed
    as-is. Progress is logged every 1000 processed records.

    Returns:
        Target organizations keyed by generated organization name.

    Raises:
        SystemExit: when two different provider/group pairs map to the same
            organization name.
    """
    providers_by_account_key = saml_groups.saml_providers_by_account_key(providers)
    targets: dict[str, organization_types.TargetOrganization] = {}
    collisions: set[str] = set()
    started = time.perf_counter()
    progress_step = 1000
    if saml_group_users is not None:
        # Only lists report their size; any other iterable is logged as 0.
        reused_count = len(saml_group_users) if isinstance(saml_group_users, list) else 0
        log.info(
            "Reusing %d precomputed SAML group user(s) loaded earlier in this run ...",
            reused_count,
        )
        for processed, group_user in enumerate(saml_group_users, start=1):
            _record_saml_group_user_memberships(targets, collisions, group_user)
            if processed % progress_step == 0:
                _log_target_collection_progress(processed, started, targets)
    else:
        log.info("Streaming users once and extracting SAML group memberships ...")
        streamed_users = shared_sourcegraph.list_users_streaming(client)
        for processed, raw_user in enumerate(streamed_users, start=1):
            group_user = saml_groups.compact_saml_group_user(
                raw_user,
                providers_by_account_key,
                attribute_names_by_provider,
            )
            if group_user is not None:
                _record_saml_group_user_memberships(targets, collisions, group_user)
            if processed % progress_step == 0:
                _log_target_collection_progress(processed, started, targets)
    _raise_for_target_collisions(collisions)
    _log_target_collection_summary(targets)
    return targets
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def _record_saml_group_user_memberships(
    targets: dict[str, organization_types.TargetOrganization],
    collisions: set[str],
    user: shared_types.SamlGroupUser,
) -> None:
    """Record every SAML group membership of ``user`` into ``targets``.

    Name collisions found along the way are collected in ``collisions``.
    """
    for group_membership in user.saml_group_memberships:
        provider_config_id = group_membership.provider_config_id
        group_name = group_membership.group_name
        _record_target_organization_membership(
            targets, collisions, provider_config_id, group_name, user
        )
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def _record_target_organization_membership(
    targets: dict[str, organization_types.TargetOrganization],
    collisions: set[str],
    provider_config_id: str,
    group_name: str,
    user: shared_types.SamlGroupUser,
) -> None:
    """Add ``user`` to the target organization for one provider/group pair.

    If the generated organization name already belongs to a different
    provider/group pair, a description of the clash is added to ``collisions``
    and the membership is not recorded.
    """
    organization_name = organization_name_for_saml_group(provider_config_id, group_name)
    existing_target = targets.get(organization_name)
    if existing_target is not None:
        same_origin = (
            existing_target.provider_config_id == provider_config_id
            and existing_target.saml_group == group_name
        )
        if not same_origin:
            collisions.add(
                f"{organization_name!r} maps both "
                f"{existing_target.provider_config_id!r}/{existing_target.saml_group!r} "
                f"and {provider_config_id!r}/{group_name!r}"
            )
            return

    target = existing_target or organization_types.TargetOrganization(
        name=organization_name,
        provider_config_id=provider_config_id,
        saml_group=group_name,
    )
    desired_member = {"id": user.user_id, "username": user.username}
    target.desired_members_by_id[user.user_id] = desired_member
    targets[organization_name] = target
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def _log_target_collection_progress(
    completed: int,
    started: float,
    targets: dict[str, organization_types.TargetOrganization],
) -> None:
    """Log processed-record count, elapsed time, throughput and target count.

    ``started`` is a ``time.perf_counter()`` reading taken when collection
    began.
    """
    elapsed = time.perf_counter() - started
    # Report a zero rate rather than dividing by a non-positive duration.
    if elapsed > 0:
        rate = completed / elapsed
    else:
        rate = 0.0
    target_count = len(targets)
    log.info(
        "Processed %d user record(s) for SAML groups in %.0fs "
        "(%.0f users/sec); found %d target org(s).",
        completed,
        elapsed,
        rate,
        target_count,
    )
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def _raise_for_target_collisions(collisions: set[str]) -> None:
    """Abort when any organization-name collision was recorded.

    Raises:
        SystemExit: carrying a message that lists every collision in sorted
            order, one bullet per line.
    """
    if collisions:
        bullet_lines = "".join(f"\n - {collision}" for collision in sorted(collisions))
        raise SystemExit(
            "FATAL: SAML group org-name collision(s) after Sourcegraph-safe normalization:"
            + bullet_lines
        )
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def _log_target_collection_summary(
    targets: dict[str, organization_types.TargetOrganization],
) -> None:
    """Log how many target organizations and desired memberships were found."""
    membership_total = 0
    for target in targets.values():
        membership_total += len(target.desired_members_by_id)
    log.info(
        "Found %d SAML-backed target org(s) covering %d desired membership(s).",
        len(targets),
        membership_total,
    )
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def organization_name_for_saml_group(provider_config_id: str, group_name: str) -> str:
    """Return the organization name generated for one provider/group pair.

    Both inputs are normalized into org-name parts and joined with ``-``.

    Raises:
        SystemExit: when either part normalizes to nothing, or when the joined
            name is longer than ``ORGANIZATION_NAME_MAX_LENGTH``.
    """
    name_parts = (
        _organization_name_part(provider_config_id, "auth provider configID"),
        _organization_name_part(group_name, "SAML group name"),
    )
    organization_name = "-".join(name_parts)
    if len(organization_name) <= ORGANIZATION_NAME_MAX_LENGTH:
        return organization_name
    raise SystemExit(
        f"FATAL: generated org name for configID={provider_config_id!r} "
        f"group={group_name!r} is {len(organization_name)} characters; "
        f"Sourcegraph org names must be <= {ORGANIZATION_NAME_MAX_LENGTH}."
    )
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def _organization_name_part(value: str, label: str) -> str:
    """Normalize ``value`` into one part of an organization name.

    Surrounding whitespace is stripped, every run of characters outside
    ``[A-Za-z0-9]`` becomes ``-``, dash runs are collapsed and leading or
    trailing dashes are removed.

    Raises:
        SystemExit: when nothing is left after normalization; ``label`` names
            the offending input in the message.
    """
    dashed = _ORGANIZATION_NAME_PART_RE.sub("-", value.strip())
    collapsed = _ORGANIZATION_NAME_DASH_RUN_RE.sub("-", dashed)
    name_part = collapsed.strip("-")
    if name_part:
        return name_part
    raise SystemExit(
        f"FATAL: {label} {value!r} cannot be converted to a Sourcegraph org-name part."
    )
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
def _load_current_organization_states(
    client: src.SourcegraphClient,
    organization_names: list[str],
    parallelism: int,
    worker_pool: ThreadPoolExecutor | None = None,
) -> tuple[organization_types.OrgMember, dict[str, organization_types.OrganizationState]]:
    """Load the live state, including members, of the named organizations.

    Names are looked up in batches of ``ORGANIZATION_LOOKUP_BATCH_SIZE`` on the
    thread pool; afterwards the members of every organization that exists are
    fetched on the same pool. Counters are recorded on the surrounding event.

    Returns:
        ``(current_user, states)`` where ``states`` is keyed by organization
        name.

    Raises:
        RuntimeError: when the batches disagree on the current user, or when
            no batch reported one.
    """
    states_by_name: dict[str, organization_types.OrganizationState] = {}
    authenticated_user: organization_types.OrgMember | None = None
    lookup_batches = list(_chunks(organization_names, ORGANIZATION_LOOKUP_BATCH_SIZE))
    with src.event(
        "load_current_organization_states",
        organization_count=len(organization_names),
        lookup_batch_count=len(lookup_batches),
        member_page_size=ORGANIZATION_MEMBER_PAGE_SIZE,
    ) as load_event:
        with run_context.thread_pool(parallelism, worker_pool) as executor:
            lookup_futures = [
                src.submit_with_log_context(executor, _fetch_organization_batch, client, batch)
                for batch in lookup_batches
            ]
            for finished_lookup in as_completed(lookup_futures):
                lookup = finished_lookup.result()
                reported_user = lookup["current_user"]
                if authenticated_user is None:
                    authenticated_user = reported_user
                elif authenticated_user["id"] != reported_user["id"]:
                    raise RuntimeError(
                        "currentUser changed between organization lookup batches "
                        f"({authenticated_user['username']} vs {reported_user['username']})"
                    )
                states_by_name.update(lookup["states"])

            # Only organizations that already exist have members to page through.
            existing_states = [
                state for state in states_by_name.values() if state.id is not None
            ]
            load_event["existing_organizations_needing_member_pages"] = len(existing_states)
            if existing_states:
                state_by_future = {
                    src.submit_with_log_context(
                        executor, _fetch_all_members, client, existing_state
                    ): existing_state
                    for existing_state in existing_states
                }
                for finished_members in as_completed(state_by_future):
                    owning_state = state_by_future[finished_members]
                    owning_state.members_by_id.update(
                        (member["id"], member) for member in finished_members.result()
                    )
            load_event["existing_organizations"] = sum(
                1 for state in states_by_name.values() if state.id
            )
            load_event["total_current_members"] = sum(
                len(state.members_by_id) for state in states_by_name.values()
            )

    if authenticated_user is None:
        raise RuntimeError("currentUser was not returned while loading organizations")
    return authenticated_user, states_by_name
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
def _fetch_organization_batch(
    client: src.SourcegraphClient,
    organization_names: list[str],
) -> organization_types.OrganizationBatchLookup:
    """Look up one batch of organizations by exact name in a single query.

    Every name gets its own aliased ``organizations(query: ...)`` search; only
    a node whose ``name`` equals the requested name counts as a match. Names
    without an exact match are returned as states with ``id=None``. Members are
    not loaded here, so every returned state has an empty ``members_by_id``.

    Returns:
        A mapping with ``current_user`` and ``states`` (keyed by the requested
        organization names).

    Raises:
        RuntimeError: when ``currentUser`` is null or an aliased lookup is
            missing from the response.
    """
    query = _organization_batch_query(len(organization_names))
    # Only variables that the generated query declares are sent: the shared
    # search limit plus one `name<i>` per organization.
    variables: dict[str, Any] = {"organizationFirst": ORGANIZATION_SEARCH_RESULT_LIMIT}
    for index, organization_name in enumerate(organization_names):
        variables[f"name{index}"] = organization_name
    with src.event(
        "organization_batch_lookup",
        level="DEBUG",
        organization_count=len(organization_names),
    ) as lookup_event:
        data = client.graphql(query, variables)
        current_user = cast(organization_types.OrgMember | None, data.get("currentUser"))
        if current_user is None:
            raise RuntimeError("currentUser is null; SAML org sync requires an authenticated token")
        states: dict[str, organization_types.OrganizationState] = {}
        existing_count = 0
        for index, organization_name in enumerate(organization_names):
            raw_connection = cast(dict[str, Any] | None, data.get(f"organization{index}"))
            if raw_connection is None:
                raise RuntimeError(
                    f"organizations lookup alias organization{index} was missing from response"
                )
            raw_organizations = cast(list[dict[str, Any]], raw_connection["nodes"])
            # The search can return several organizations; accept only the
            # one whose name is exactly the requested name.
            raw_organization = next(
                (
                    organization
                    for organization in raw_organizations
                    if organization.get("name") == organization_name
                ),
                None,
            )
            if raw_organization is None:
                total_count = cast(int, raw_connection["totalCount"])
                if total_count >= ORGANIZATION_SEARCH_RESULT_LIMIT:
                    # The result page was full, so an exact match could sit
                    # beyond the fetched nodes; warn, but still treat the
                    # organization as missing.
                    log.warning(
                        "Org lookup for %s returned %d search result(s) without an exact "
                        "match in the first %d; treating it as missing.",
                        organization_name,
                        total_count,
                        ORGANIZATION_SEARCH_RESULT_LIMIT,
                    )
                states[organization_name] = organization_types.OrganizationState(
                    id=None,
                    name=organization_name,
                    members_by_id={},
                )
                continue
            existing_count += 1
            states[organization_name] = organization_types.OrganizationState(
                id=cast(str, raw_organization["id"]),
                name=cast(str, raw_organization["name"]),
                members_by_id={},
            )
        lookup_event["existing_organizations"] = existing_count
        return {"current_user": current_user, "states": states}
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
def _organization_batch_query(organization_count: int) -> str:
    """Build the GraphQL lookup query for ``organization_count`` organizations.

    The query declares ``$organizationFirst`` plus one ``$name<i>`` variable
    per organization. It selects ``currentUser`` and, per organization, an
    ``organization<i>`` alias returning ``totalCount`` and the ``id`` / ``name``
    of the matching nodes.
    """
    indexes = range(organization_count)
    variable_definitions = ", ".join(
        ["$organizationFirst: Int!", *(f"$name{index}: String!" for index in indexes)]
    )
    selections = ["currentUser { id username }"]
    selections.extend(
        f"""
      organization{index}: organizations(first: $organizationFirst, query: $name{index}) {{
        totalCount
        nodes {{
          id
          name
        }}
      }}"""
        for index in indexes
    )
    selection_text = "\n".join(selections)
    return f"query SamlOrganizationLookup({variable_definitions}) {{\n{selection_text}\n}}\n"
|
|
682
|
+
|
|
683
|
+
|
|
684
|
+
def _fetch_all_members(
    client: src.SourcegraphClient,
    state: organization_types.OrganizationState,
) -> list[organization_types.OrgMember]:
    """Return every member of the organization described by ``state``.

    An organization without an id yields an empty list without any request.
    Otherwise all member nodes are streamed in pages of
    ``ORGANIZATION_MEMBER_PAGE_SIZE`` and collected into a list.
    """
    if state.id is None:
        return []
    with src.event("organization_members", level="DEBUG", organization_name=state.name):
        member_nodes = client.stream_connection_nodes(
            queries.QUERY_ORGANIZATION_MEMBERS_PAGE,
            {"id": state.id},
            connection_path=("node", "members"),
            page_size=ORGANIZATION_MEMBER_PAGE_SIZE,
        )
        members: list[organization_types.OrgMember] = []
        for node in member_nodes:
            members.append(cast(organization_types.OrgMember, node))
        return members
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
def _plan_organization_sync(
    targets: dict[str, organization_types.TargetOrganization],
    current_states: dict[str, organization_types.OrganizationState],
    current_user: organization_types.OrgMember,
) -> organization_types.OrganizationPlan:
    """Work out which organizations to create and which memberships to change.

    Targets are visited in name order. For each one, users desired but not
    currently present become additions, and users present but not desired
    become removals; both lists are ordered by username within the
    organization.

    Returns:
        A mapping with ``create_names``, ``additions`` and ``removals``.
    """

    def by_username(entry: tuple[str, organization_types.OrgMember]) -> str:
        return entry[1]["username"]

    names_to_create: list[str] = []
    planned_additions: list[organization_types.OrganizationUserChange] = []
    planned_removals: list[organization_types.OrganizationUserChange] = []
    for organization_name, target in sorted(targets.items()):
        existing = current_states[organization_name]
        present = dict(existing.members_by_id)
        if existing.id is None:
            names_to_create.append(organization_name)
            # createOrganization automatically adds the caller. Account for
            # that before planning adds/removes so we do not call add for the
            # caller or leave them in the org when they are not in the SAML group.
            present[current_user["id"]] = current_user
        wanted = target.desired_members_by_id
        planned_additions.extend(
            organization_types.OrganizationUserChange(
                organization_name=organization_name,
                user_id=user_id,
                username=member["username"],
            )
            for user_id, member in sorted(wanted.items(), key=by_username)
            if user_id not in present
        )
        planned_removals.extend(
            organization_types.OrganizationUserChange(
                organization_name=organization_name,
                user_id=user_id,
                username=member["username"],
            )
            for user_id, member in sorted(present.items(), key=by_username)
            if user_id not in wanted
        )
    return {
        "create_names": names_to_create,
        "additions": planned_additions,
        "removals": planned_removals,
    }
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
def _expected_states_from_targets(
    targets: dict[str, organization_types.TargetOrganization],
    current_states: dict[str, organization_types.OrganizationState],
) -> dict[str, organization_types.OrganizationState]:
    """Describe the state each target organization should end up in.

    Every expected state keeps the ID found in ``current_states`` and takes a
    copy of the target's desired members as its membership.
    """
    expected: dict[str, organization_types.OrganizationState] = {}
    for organization_name, target in targets.items():
        expected[organization_name] = organization_types.OrganizationState(
            id=current_states[organization_name].id,
            name=organization_name,
            members_by_id=dict(target.desired_members_by_id),
        )
    return expected
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
def _apply_create_organizations(
    client: src.SourcegraphClient,
    organization_names: list[str],
    current_states: dict[str, organization_types.OrganizationState],
    current_user: organization_types.OrgMember,
    parallelism: int,
    worker_pool: ThreadPoolExecutor | None = None,
) -> shared_types.MutationCounts:
    """Create the named organizations concurrently and tally the outcomes.

    One ``_create_organization`` task is submitted per name. Each state a
    task returns replaces that organization's entry in ``current_states``.
    Failures are logged and fed to a circuit breaker; once the breaker
    reports open, every future that is not yet done is cancelled.

    Returns:
        Counts of succeeded, failed and canceled creations. An empty
        ``organization_names`` returns default counts without opening an event.
    """
    if not organization_names:
        return shared_types.MutationCounts()
    with src.event(
        "apply_create_organizations",
        organization_count=len(organization_names),
        parallelism=parallelism,
    ) as batch_event:
        breaker = permissions_apply.CircuitBreaker()
        tally = {"succeeded": 0, "failed": 0, "canceled": 0}
        with run_context.thread_pool(parallelism, worker_pool) as executor:
            name_by_future = {}
            for pending_name in organization_names:
                submitted = src.submit_with_log_context(
                    executor,
                    _create_organization,
                    client,
                    pending_name,
                    current_user,
                )
                name_by_future[submitted] = pending_name
            for finished in as_completed(name_by_future):
                organization_name = name_by_future[finished]
                try:
                    current_states[organization_name] = finished.result()
                    tally["succeeded"] += 1
                    breaker.record(success=True)
                    log.info(" OK create org %s.", organization_name)
                except CancelledError:
                    # Cancelled tasks are counted but never reach the breaker.
                    tally["canceled"] += 1
                    continue
                except Exception as exception:
                    tally["failed"] += 1
                    breaker.record(success=False)
                    log.error(" FAIL create org %s: %s", organization_name, exception)
                if not breaker.is_open():
                    continue
                for other in name_by_future:
                    if not other.done():
                        other.cancel()
        for outcome, total in tally.items():
            batch_event[outcome] = total
        batch_event["circuit_broken"] = breaker.is_open()
        return shared_types.MutationCounts(**tally)
|
|
816
|
+
|
|
817
|
+
|
|
818
|
+
def _create_organization(
    client: src.SourcegraphClient,
    organization_name: str,
    current_user: organization_types.OrgMember,
) -> organization_types.OrganizationState:
    """Create one organization and return its resulting state.

    If the mutation fails with a GraphQL error containing
    ``_ORGANIZATION_EXISTS_TEXT``, a warning is logged and the existing
    organization is re-read instead. Any other GraphQL error is re-raised.
    A newly created organization is reported with ``current_user`` as its
    only member.
    """
    with src.event("create_organization", organization_name=organization_name):
        variables = {"name": organization_name, "displayName": None}
        try:
            data = client.graphql(queries.MUTATION_CREATE_ORGANIZATION, variables)
        except src.GraphQLError as exception:
            already_exists = _ORGANIZATION_EXISTS_TEXT in str(exception)
            if not already_exists:
                raise
            log.warning(
                "createOrganization reported %s already exists; re-reading it and continuing.",
                organization_name,
            )
            return _fetch_single_organization_state(client, organization_name)
        created = cast(organization_types.CreatedOrganization, data["createOrganization"])
        return organization_types.OrganizationState(
            id=created["id"],
            name=created["name"],
            members_by_id={current_user["id"]: current_user},
        )
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
def _fetch_single_organization_state(
    client: src.SourcegraphClient,
    organization_name: str,
) -> organization_types.OrganizationState:
    """Look up one organization by name and load its current members.

    Raises:
        RuntimeError: If the lookup returns a state without an ID.
    """
    state = _fetch_organization_batch(client, [organization_name])["states"][organization_name]
    if state.id is None:
        message = (
            f"organization {organization_name!r} still does not exist "
            "after createOrganization conflict"
        )
        raise RuntimeError(message)
    state.members_by_id.update(
        (member["id"], member) for member in _fetch_all_members(client, state)
    )
    return state
|
|
859
|
+
|
|
860
|
+
|
|
861
|
+
def _apply_user_changes(
    client: src.SourcegraphClient,
    changes: list[organization_types.OrganizationUserChange],
    current_states: dict[str, organization_types.OrganizationState],
    change_kind: organization_types.OrganizationChangeKind,
    parallelism: int,
    worker_pool: ThreadPoolExecutor | None = None,
) -> shared_types.MutationCounts:
    """Apply a batch of membership changes concurrently and tally the outcomes.

    One ``_apply_user_change`` task is submitted per change, paired with the
    state of the organization it targets. Failures are logged and fed to a
    circuit breaker; once the breaker reports open, every future that is not
    yet done is cancelled.

    Returns:
        Counts of succeeded, failed and canceled changes. An empty
        ``changes`` returns default counts without opening an event.
    """
    if not changes:
        return shared_types.MutationCounts()
    with src.event(
        "apply_organization_user_changes",
        change_kind=change_kind,
        change_count=len(changes),
        parallelism=parallelism,
    ) as batch_event:
        breaker = permissions_apply.CircuitBreaker()
        tally = {"succeeded": 0, "failed": 0, "canceled": 0}
        # The preposition depends only on the batch's kind, so pick it once.
        preposition = "to" if change_kind == "add" else "from"
        with run_context.thread_pool(parallelism, worker_pool) as executor:
            change_by_future = {}
            for pending_change in changes:
                submitted = src.submit_with_log_context(
                    executor,
                    _apply_user_change,
                    client,
                    pending_change,
                    current_states[pending_change.organization_name],
                    change_kind,
                )
                change_by_future[submitted] = pending_change
            for finished in as_completed(change_by_future):
                change = change_by_future[finished]
                try:
                    finished.result()
                    tally["succeeded"] += 1
                    breaker.record(success=True)
                    log.info(
                        " OK %s %s %s org %s.",
                        change_kind,
                        change.username,
                        preposition,
                        change.organization_name,
                    )
                except CancelledError:
                    # Cancelled tasks are counted but never reach the breaker.
                    tally["canceled"] += 1
                    continue
                except Exception as exception:
                    tally["failed"] += 1
                    breaker.record(success=False)
                    log.error(
                        " FAIL %s %s %s org %s: %s",
                        change_kind,
                        change.username,
                        preposition,
                        change.organization_name,
                        exception,
                    )
                if not breaker.is_open():
                    continue
                for other in change_by_future:
                    if not other.done():
                        other.cancel()
        for outcome, total in tally.items():
            batch_event[outcome] = total
        batch_event["circuit_broken"] = breaker.is_open()
        return shared_types.MutationCounts(**tally)
|
|
933
|
+
|
|
934
|
+
|
|
935
|
+
def _apply_user_change(
    client: src.SourcegraphClient,
    change: organization_types.OrganizationUserChange,
    state: organization_types.OrganizationState,
    change_kind: organization_types.OrganizationChangeKind,
) -> None:
    """Send the mutation for a single membership change.

    An ``"add"`` adds the user by username; a GraphQL error containing
    ``_ALREADY_MEMBER_TEXT`` is logged and treated as success. Any other
    kind removes the user by ID.

    Raises:
        RuntimeError: If ``state`` has no organization ID.
    """
    organization_id = state.id
    if organization_id is None:
        raise RuntimeError(f"organization {change.organization_name!r} has no ID")
    if change_kind != "add":
        with src.event(
            "remove_user_from_organization",
            organization_name=change.organization_name,
            username=change.username,
        ):
            client.graphql(
                queries.MUTATION_REMOVE_USER_FROM_ORGANIZATION,
                {"organization": organization_id, "user": change.user_id},
            )
        return
    with src.event(
        "add_user_to_organization",
        organization_name=change.organization_name,
        username=change.username,
    ):
        try:
            client.graphql(
                queries.MUTATION_ADD_USER_TO_ORGANIZATION,
                {"organization": organization_id, "username": change.username},
            )
        except src.GraphQLError as exception:
            if _ALREADY_MEMBER_TEXT not in str(exception):
                raise
            log.info(
                " Already a member: %s in org %s; treating as success.",
                change.username,
                change.organization_name,
            )
|
|
973
|
+
|
|
974
|
+
|
|
975
|
+
def _snapshot_from_states(
    endpoint: str,
    targets: dict[str, organization_types.TargetOrganization],
    states: dict[str, organization_types.OrganizationState],
) -> organization_types.OrganizationSnapshot:
    """Assemble a snapshot of the target organizations.

    Each target gets one entry, in name order, holding its ID, provider
    config ID, SAML group, current members and desired members. The snapshot
    also carries the schema version, a UTC capture time with second
    precision, the endpoint, and aggregate counts over all entries.
    """
    organizations: dict[str, organization_types.OrganizationSnapshotEntry] = {}
    for organization_name in sorted(targets):
        target = targets[organization_name]
        state = states[organization_name]
        entry: organization_types.OrganizationSnapshotEntry = {
            "id": state.id,
            "provider_config_id": target.provider_config_id,
            "saml_group": target.saml_group,
            "members": _sorted_members(state.members_by_id),
            "desired_members": _sorted_members(target.desired_members_by_id),
        }
        organizations[organization_name] = entry
    entries = list(organizations.values())
    stats = {
        "target_organizations": len(organizations),
        "existing_organizations": len([item for item in entries if item["id"] is not None]),
        "total_current_members": sum(len(item["members"]) for item in entries),
        "total_desired_members": sum(len(item["desired_members"]) for item in entries),
    }
    return {
        "schema_version": ORGANIZATION_SNAPSHOT_SCHEMA_VERSION,
        "captured_at": datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds"),
        "endpoint": endpoint,
        "stats": stats,
        "organizations": organizations,
    }
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
def _sorted_members(
|
|
1011
|
+
members_by_id: dict[str, organization_types.OrgMember],
|
|
1012
|
+
) -> list[organization_types.OrgMember]:
|
|
1013
|
+
return sorted(
|
|
1014
|
+
({"id": member["id"], "username": member["username"]} for member in members_by_id.values()),
|
|
1015
|
+
key=lambda member: member["username"],
|
|
1016
|
+
)
|
|
1017
|
+
|
|
1018
|
+
|
|
1019
|
+
def _render_organization_diff(
    before: organization_types.OrganizationSnapshot, after: organization_types.OrganizationSnapshot
) -> str:
    """Render a text report of username changes between two snapshots.

    Organizations are listed in name order and only when at least one
    username was added or removed. The report ends with a blank line and a
    summary of the totals.

    Returns:
        The report, or exactly ``"No changes."`` when no organization
        differs in membership.
    """
    report: list[str] = []
    added_total = 0
    removed_total = 0
    before_organizations = before["organizations"]
    after_organizations = after["organizations"]
    for organization_name in sorted(before_organizations.keys() | after_organizations.keys()):
        before_entry = before_organizations.get(organization_name)
        after_entry = after_organizations.get(organization_name)
        before_members = _snapshot_usernames(before_entry["members"] if before_entry else [])
        after_members = _snapshot_usernames(after_entry["members"] if after_entry else [])
        added = sorted(after_members - before_members)
        removed = sorted(before_members - after_members)
        if not (added or removed):
            continue
        report.append(f"=== {organization_name} ===")
        if before_entry and before_entry["id"] is None:
            report.append(" * organization does not exist yet")
        if added:
            report.append(f" + added ({len(added)}): {', '.join(added)}")
        if removed:
            report.append(f" - removed ({len(removed)}): {', '.join(removed)}")
        added_total += len(added)
        removed_total += len(removed)
    if not report:
        return "No changes."
    report += [
        "",
        f"Summary: {added_total} member(s) added, {removed_total} member(s) removed.",
    ]
    return "\n".join(report)
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
def _build_organization_snapshot_diff(
    before: organization_types.OrganizationSnapshot,
    after: organization_types.OrganizationSnapshot,
) -> organization_types.OrganizationSnapshotDiff:
    """Build a structured diff between two organization snapshots.

    An organization is included when usernames were added or removed, or when
    it counts as created, meaning it is absent from ``before`` or has no ID
    there. Entry metadata comes from the ``after`` entry when one exists,
    otherwise from the ``before`` entry.
    """
    changed: list[organization_types.OrganizationSnapshotDiffEntry] = []
    created_total = 0
    added_total = 0
    removed_total = 0
    before_organizations = before["organizations"]
    after_organizations = after["organizations"]
    for organization_name in sorted(before_organizations.keys() | after_organizations.keys()):
        before_entry = before_organizations.get(organization_name)
        after_entry = after_organizations.get(organization_name)
        before_members = _snapshot_usernames(before_entry["members"] if before_entry else [])
        after_members = _snapshot_usernames(after_entry["members"] if after_entry else [])
        added = sorted(after_members - before_members)
        removed = sorted(before_members - after_members)
        created = before_entry is None or before_entry["id"] is None
        if not (added or removed or created):
            continue
        source_entry = after_entry or before_entry
        if source_entry is None:
            continue
        added_total += len(added)
        removed_total += len(removed)
        if created:
            created_total += 1
        changed.append(
            {
                "name": organization_name,
                "id": source_entry["id"],
                "provider_config_id": source_entry["provider_config_id"],
                "saml_group": source_entry["saml_group"],
                "created": created,
                "before_count": len(before_members),
                "after_count": len(after_members),
                "added": added,
                "removed": removed,
            }
        )
    summary = {
        "organizations_changed": len(changed),
        "organizations_created": created_total,
        "members_added": added_total,
        "members_removed": removed_total,
    }
    return {
        "schema_version": ORGANIZATION_SNAPSHOT_DIFF_SCHEMA_VERSION,
        "diff_kind": "saml_organizations",
        "before_captured_at": before["captured_at"],
        "after_captured_at": after["captured_at"],
        "endpoint": after["endpoint"],
        "summary": summary,
        "organizations": changed,
    }
|
|
1105
|
+
|
|
1106
|
+
|
|
1107
|
+
def _snapshot_usernames(members: list[organization_types.OrgMember]) -> set[str]:
|
|
1108
|
+
return {member["username"] for member in members}
|
|
1109
|
+
|
|
1110
|
+
|
|
1111
|
+
def _validate_organization_sync(
    after_snapshot: organization_types.OrganizationSnapshot,
    expected_snapshot: organization_types.OrganizationSnapshot,
) -> None:
    """Log whether the post-sync snapshot matches the expected one.

    The two snapshots are compared with ``_render_organization_diff``. A
    mismatch is logged as a warning that includes the rendered diff; nothing
    is raised.
    """
    rendered = _render_organization_diff(after_snapshot, expected_snapshot)
    if rendered != "No changes.":
        log.warning(
            "VALIDATION: target org memberships differ from desired SAML groups:\n%s", rendered
        )
        return
    log.info("VALIDATION OK: all target org memberships match discovered SAML groups.")
|
|
1120
|
+
|
|
1121
|
+
|
|
1122
|
+
def _write_organization_snapshot(
    path: Path, snapshot: organization_types.OrganizationSnapshot
) -> None:
    """Write ``snapshot`` to ``path`` as indented, key-sorted JSON.

    Missing parent directories are created first. The length of the written
    text is recorded on the surrounding ``disk_io`` event under ``bytes``.
    """
    event_fields = {
        "level": "DEBUG",
        "op": "write",
        "path": str(path),
        "file_kind": "organization_snapshot",
    }
    with src.event("disk_io", **event_fields) as disk_event:
        path.parent.mkdir(parents=True, exist_ok=True)
        serialized = f"{json.dumps(snapshot, indent=2, sort_keys=True)}\n"
        path.write_text(serialized)
        disk_event["bytes"] = len(serialized)
|
|
1136
|
+
|
|
1137
|
+
|
|
1138
|
+
def _write_organization_snapshot_diff(
    path: Path,
    before: organization_types.OrganizationSnapshot,
    after: organization_types.OrganizationSnapshot,
) -> None:
    """Write the diff between two snapshots to ``path`` as key-sorted JSON.

    Missing parent directories are created before the diff is built. The
    length of the written text is recorded on the surrounding ``disk_io``
    event under ``bytes``.
    """
    event_fields = {
        "level": "DEBUG",
        "op": "write",
        "path": str(path),
        "file_kind": "organization_snapshot_diff",
    }
    with src.event("disk_io", **event_fields) as disk_event:
        path.parent.mkdir(parents=True, exist_ok=True)
        snapshot_diff = _build_organization_snapshot_diff(before, after)
        serialized = f"{json.dumps(snapshot_diff, indent=2, sort_keys=True)}\n"
        path.write_text(serialized)
        disk_event["bytes"] = len(serialized)
|
|
1155
|
+
|
|
1156
|
+
|
|
1157
|
+
def _organization_snapshot_path(
    timestamp: str,
    endpoint: str,
    command: str,
    state: str,
) -> Path:
    """Return the backup path for a SAML-organizations snapshot file.

    Delegates to ``backups.backup_path`` with the fixed kind
    ``"saml-organizations"``.
    """
    backup_kind = "saml-organizations"
    return backups.backup_path(backup_kind, timestamp, endpoint, command, state)
|
|
1164
|
+
|
|
1165
|
+
|
|
1166
|
+
def _chunks(values: list[str], size: int) -> list[list[str]]:
|
|
1167
|
+
return [values[start : start + size] for start in range(0, len(values), size)]
|