kleinkram 0.43.2.dev20250331124109__py3-none-any.whl → 0.58.0.dev20260110152317__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. kleinkram/api/client.py +6 -18
  2. kleinkram/api/deser.py +152 -1
  3. kleinkram/api/file_transfer.py +202 -101
  4. kleinkram/api/pagination.py +11 -2
  5. kleinkram/api/query.py +10 -10
  6. kleinkram/api/routes.py +192 -59
  7. kleinkram/auth.py +108 -7
  8. kleinkram/cli/_action.py +131 -0
  9. kleinkram/cli/_download.py +8 -19
  10. kleinkram/cli/_endpoint.py +2 -4
  11. kleinkram/cli/_file.py +6 -18
  12. kleinkram/cli/_file_validator.py +125 -0
  13. kleinkram/cli/_list.py +5 -15
  14. kleinkram/cli/_mission.py +24 -28
  15. kleinkram/cli/_project.py +10 -26
  16. kleinkram/cli/_run.py +220 -0
  17. kleinkram/cli/_upload.py +58 -26
  18. kleinkram/cli/_verify.py +59 -16
  19. kleinkram/cli/app.py +56 -17
  20. kleinkram/cli/error_handling.py +1 -3
  21. kleinkram/config.py +6 -21
  22. kleinkram/core.py +53 -43
  23. kleinkram/errors.py +12 -0
  24. kleinkram/models.py +51 -1
  25. kleinkram/printing.py +229 -18
  26. kleinkram/utils.py +10 -24
  27. kleinkram/wrappers.py +54 -30
  28. {kleinkram-0.43.2.dev20250331124109.dist-info → kleinkram-0.58.0.dev20260110152317.dist-info}/METADATA +6 -4
  29. kleinkram-0.58.0.dev20260110152317.dist-info/RECORD +53 -0
  30. {kleinkram-0.43.2.dev20250331124109.dist-info → kleinkram-0.58.0.dev20260110152317.dist-info}/WHEEL +1 -1
  31. {kleinkram-0.43.2.dev20250331124109.dist-info → kleinkram-0.58.0.dev20260110152317.dist-info}/top_level.txt +0 -1
  32. {testing → tests}/backend_fixtures.py +27 -3
  33. tests/conftest.py +1 -1
  34. tests/generate_test_data.py +314 -0
  35. tests/test_config.py +2 -6
  36. tests/test_core.py +11 -31
  37. tests/test_end_to_end.py +3 -5
  38. tests/test_fixtures.py +3 -5
  39. tests/test_printing.py +9 -11
  40. tests/test_utils.py +1 -3
  41. tests/test_wrappers.py +9 -27
  42. kleinkram-0.43.2.dev20250331124109.dist-info/RECORD +0 -50
  43. testing/__init__.py +0 -0
  44. {kleinkram-0.43.2.dev20250331124109.dist-info → kleinkram-0.58.0.dev20260110152317.dist-info}/entry_points.txt +0 -0
kleinkram/api/routes.py CHANGED
@@ -1,7 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import json
4
+ import tempfile
4
5
  from enum import Enum
6
+ from pathlib import Path
5
7
  from typing import Any
6
8
  from typing import Dict
7
9
  from typing import Generator
@@ -12,6 +14,7 @@ from typing import Tuple
12
14
  from uuid import UUID
13
15
 
14
16
  import httpx
17
+ import typer
15
18
 
16
19
  import kleinkram.errors
17
20
  from kleinkram._version import __version__
@@ -20,13 +23,17 @@ from kleinkram.api.client import AuthenticatedClient
20
23
  from kleinkram.api.deser import FileObject
21
24
  from kleinkram.api.deser import MissionObject
22
25
  from kleinkram.api.deser import ProjectObject
26
+ from kleinkram.api.deser import RunObject
27
+ from kleinkram.api.deser import _parse_action_template
23
28
  from kleinkram.api.deser import _parse_file
24
29
  from kleinkram.api.deser import _parse_mission
25
30
  from kleinkram.api.deser import _parse_project
31
+ from kleinkram.api.deser import _parse_run
26
32
  from kleinkram.api.pagination import paginated_request
27
33
  from kleinkram.api.query import FileQuery
28
34
  from kleinkram.api.query import MissionQuery
29
35
  from kleinkram.api.query import ProjectQuery
36
+ from kleinkram.api.query import RunQuery
30
37
  from kleinkram.api.query import file_query_is_unique
31
38
  from kleinkram.api.query import mission_query_is_unique
32
39
  from kleinkram.api.query import project_query_is_unique
@@ -38,12 +45,17 @@ from kleinkram.errors import InvalidMissionQuery
38
45
  from kleinkram.errors import InvalidProjectQuery
39
46
  from kleinkram.errors import MissionExists
40
47
  from kleinkram.errors import MissionNotFound
48
+ from kleinkram.errors import MissionValidationError
41
49
  from kleinkram.errors import ProjectExists
42
50
  from kleinkram.errors import ProjectNotFound
51
+ from kleinkram.errors import ProjectValidationError
52
+ from kleinkram.models import ActionTemplate
43
53
  from kleinkram.models import File
44
54
  from kleinkram.models import Mission
45
55
  from kleinkram.models import Project
56
+ from kleinkram.models import Run
46
57
  from kleinkram.utils import is_valid_uuid4
58
+ from kleinkram.utils import split_args
47
59
 
48
60
  __all__ = [
49
61
  "_get_api_version",
@@ -79,6 +91,8 @@ PROJECT_ENDPOINT = "/projects"
79
91
 
80
92
  TAG_TYPE_BY_NAME = "/tag/filtered"
81
93
 
94
+ ACTION_ENDPOINT = "/action"
95
+
82
96
 
83
97
  class Params(str, Enum):
84
98
  FILE_PATTERNS = "filePatterns"
@@ -137,9 +151,7 @@ def get_files(
137
151
  max_entries: Optional[int] = None,
138
152
  ) -> Generator[File, None, None]:
139
153
  params = _file_query_to_params(file_query)
140
- response_stream = paginated_request(
141
- client, FILE_ENDPOINT, params=params, max_entries=max_entries
142
- )
154
+ response_stream = paginated_request(client, FILE_ENDPOINT, params=params, max_entries=max_entries)
143
155
  yield from map(lambda f: _parse_file(FileObject(f)), response_stream)
144
156
 
145
157
 
@@ -149,9 +161,7 @@ def get_missions(
149
161
  max_entries: Optional[int] = None,
150
162
  ) -> Generator[Mission, None, None]:
151
163
  params = _mission_query_to_params(mission_query)
152
- response_stream = paginated_request(
153
- client, MISSION_ENDPOINT, params=params, max_entries=max_entries
154
- )
164
+ response_stream = paginated_request(client, MISSION_ENDPOINT, params=params, max_entries=max_entries)
155
165
  yield from map(lambda m: _parse_mission(MissionObject(m)), response_stream)
156
166
 
157
167
 
@@ -159,36 +169,93 @@ def get_projects(
159
169
  client: AuthenticatedClient,
160
170
  project_query: ProjectQuery,
161
171
  max_entries: Optional[int] = None,
172
+ exact_match: bool = False,
162
173
  ) -> Generator[Project, None, None]:
163
174
  params = _project_query_to_params(project_query)
164
175
  response_stream = paginated_request(
165
- client, PROJECT_ENDPOINT, params=params, max_entries=max_entries
176
+ client,
177
+ PROJECT_ENDPOINT,
178
+ params=params,
179
+ max_entries=max_entries,
180
+ exact_match=exact_match,
166
181
  )
167
182
  yield from map(lambda p: _parse_project(ProjectObject(p)), response_stream)
168
183
 
169
184
 
170
- def get_project(client: AuthenticatedClient, query: ProjectQuery) -> Project:
185
+ LIST_ACTIONS_ENDPOINT = "/actions"
186
+
187
+
188
+ def get_runs(
189
+ client: AuthenticatedClient,
190
+ query: RunQuery,
191
+ ) -> Generator[Run, None, None]:
192
+
193
+ response_stream = paginated_request(client, LIST_ACTIONS_ENDPOINT)
194
+ yield from map(lambda p: _parse_run(RunObject(p)), response_stream)
195
+
196
+
197
+ def get_run(
198
+ client: AuthenticatedClient,
199
+ run_id: str,
200
+ ) -> Run:
201
+ resp = client.get(f"{ACTION_ENDPOINT}s/{run_id}")
202
+ if resp.status_code == 404:
203
+ raise kleinkram.errors.RunNotFound(f"Run not found: {run_id}")
204
+ resp.raise_for_status()
205
+ return _parse_run(RunObject(resp.json()))
206
+
207
+
208
+ def get_action_templates(
209
+ client: AuthenticatedClient,
210
+ ) -> Generator[ActionTemplate, None, None]:
211
+ response_stream = paginated_request(client, "/templates")
212
+ yield from map(lambda p: _parse_action_template(RunObject(p)), response_stream)
213
+
214
+
215
+ def get_project(client: AuthenticatedClient, query: ProjectQuery, exact_match: bool = False) -> Project:
171
216
  """\
172
217
  get a unique project by specifying a project spec
173
218
  """
174
219
  if not project_query_is_unique(query):
175
- raise InvalidProjectQuery(
176
- f"Project query does not uniquely determine project: {query}"
177
- )
220
+ raise InvalidProjectQuery(f"Project query does not uniquely determine project: {query}")
178
221
  try:
179
- return next(get_projects(client, query))
222
+ return next(get_projects(client, query, exact_match=exact_match))
180
223
  except StopIteration:
181
224
  raise ProjectNotFound(f"Project not found: {query}")
182
225
 
183
226
 
227
+ def submit_action(client: AuthenticatedClient, mission_uuid: UUID, template_uuid: UUID) -> str:
228
+ """
229
+ Submits a new action to the API and returns the action UUID.
230
+
231
+ Raises:
232
+ httpx.HTTPStatusError: If the API returns an error.
233
+ KeyError: If the response is missing 'actionUUID'.
234
+ """
235
+ submit_payload = {
236
+ "missionUUID": str(mission_uuid),
237
+ "templateUUID": str(template_uuid),
238
+ }
239
+
240
+ typer.echo("Submitting action...")
241
+ resp = client.post(f"{ACTION_ENDPOINT}s", json=submit_payload)
242
+ resp.raise_for_status() # Raises on 4xx/5xx responses
243
+
244
+ response_data = resp.json()
245
+ action_uuid_str = response_data.get("actionUUID")
246
+
247
+ if not action_uuid_str:
248
+ raise KeyError("API response missing 'actionUUID'")
249
+
250
+ return action_uuid_str
251
+
252
+
184
253
  def get_mission(client: AuthenticatedClient, query: MissionQuery) -> Mission:
185
254
  """\
186
255
  get a unique mission by specifying a mission query
187
256
  """
188
257
  if not mission_query_is_unique(query):
189
- raise InvalidMissionQuery(
190
- f"Mission query does not uniquely determine mission: {query}"
191
- )
258
+ raise InvalidMissionQuery(f"Mission query does not uniquely determine mission: {query}")
192
259
  try:
193
260
  return next(get_missions(client, query))
194
261
  except StopIteration:
@@ -207,12 +274,8 @@ def get_file(client: AuthenticatedClient, query: FileQuery) -> File:
207
274
  raise kleinkram.errors.FileNotFound(f"File not found: {query}")
208
275
 
209
276
 
210
- def _mission_name_is_available(
211
- client: AuthenticatedClient, mission_name: str, project_id: UUID
212
- ) -> bool:
213
- mission_query = MissionQuery(
214
- patterns=[mission_name], project_query=ProjectQuery(ids=[project_id])
215
- )
277
+ def _mission_name_is_available(client: AuthenticatedClient, mission_name: str, project_id: UUID) -> bool:
278
+ mission_query = MissionQuery(patterns=[mission_name], project_query=ProjectQuery(ids=[project_id]))
216
279
  try:
217
280
  _ = get_mission(client, mission_query)
218
281
  except MissionNotFound:
@@ -220,15 +283,68 @@ def _mission_name_is_available(
220
283
  return False
221
284
 
222
285
 
286
+ def _validate_mission_name(client: AuthenticatedClient, project_id: UUID, mission_name: str) -> None:
287
+ if not _mission_name_is_available(client, mission_name, project_id):
288
+ raise MissionExists(f"Mission with name: `{mission_name}` already exists" f" in project: {project_id}")
289
+
290
+ if is_valid_uuid4(mission_name):
291
+ raise ValueError(f"Mission name: `{mission_name}` is a valid UUIDv4, " "mission names must not be valid UUIDv4's")
292
+
293
+ if mission_name.endswith(" "):
294
+ raise ValueError("A mission name cannot end with a whitespace. " f"The given mission name was '{mission_name}'")
295
+
296
+
223
297
  def _project_name_is_available(client: AuthenticatedClient, project_name: str) -> bool:
224
298
  project_query = ProjectQuery(patterns=[project_name])
225
299
  try:
226
- _ = get_project(client, project_query)
300
+ _ = get_project(client, project_query, exact_match=True)
227
301
  except ProjectNotFound:
228
302
  return True
229
303
  return False
230
304
 
231
305
 
306
+ def _validate_mission_created(client: AuthenticatedClient, project_id: str, mission_name: str) -> None:
307
+ """
308
+ validate that a mission is successfully created
309
+ """
310
+ mission_ids, mission_patterns = split_args([mission_name])
311
+ project_ids, project_patterns = split_args([project_id])
312
+
313
+ project_query = ProjectQuery(ids=project_ids, patterns=project_patterns)
314
+ mission_query = MissionQuery(
315
+ ids=mission_ids,
316
+ patterns=mission_patterns,
317
+ project_query=project_query,
318
+ )
319
+ try:
320
+ with tempfile.NamedTemporaryFile(suffix=".mcap", delete=False) as tmp:
321
+ tmp.write(b"dummy content")
322
+ tmp_path = Path(tmp.name)
323
+
324
+ kleinkram.core.upload(
325
+ client=client,
326
+ query=mission_query,
327
+ file_paths=[tmp_path],
328
+ verbose=False,
329
+ )
330
+
331
+ file_query = FileQuery(
332
+ ids=[],
333
+ patterns=[tmp_path.name],
334
+ mission_query=mission_query,
335
+ )
336
+ file_parsed = get_file(client, file_query)
337
+
338
+ kleinkram.core.delete_files(client=client, file_ids=[file_parsed.id])
339
+
340
+ except Exception as e:
341
+ raise MissionValidationError(f"Mission validation failed: {e}")
342
+
343
+ finally:
344
+ if tmp_path.exists():
345
+ tmp_path.unlink()
346
+
347
+
232
348
  def _create_mission(
233
349
  client: AuthenticatedClient,
234
350
  project_id: UUID,
@@ -236,6 +352,7 @@ def _create_mission(
236
352
  *,
237
353
  metadata: Optional[Dict[str, str]] = None,
238
354
  ignore_missing_tags: bool = False,
355
+ required_tags: Optional[List[str]] = None,
239
356
  ) -> UUID:
240
357
  """\
241
358
  creates a new mission with the given name and project_id
@@ -246,17 +363,10 @@ def _create_mission(
246
363
  if metadata is None:
247
364
  metadata = {}
248
365
 
249
- if not _mission_name_is_available(client, mission_name, project_id):
250
- raise MissionExists(
251
- f"Mission with name: `{mission_name}` already exists"
252
- f" in project: {project_id}"
253
- )
366
+ _validate_mission_name(client, project_id, mission_name)
254
367
 
255
- if is_valid_uuid4(mission_name):
256
- raise ValueError(
257
- f"Mission name: `{mission_name}` is a valid UUIDv4, "
258
- "mission names must not be valid UUIDv4's"
259
- )
368
+ if required_tags and not set(required_tags).issubset(metadata.keys()):
369
+ raise InvalidMissionMetadata(f"Mission tags `{required_tags}` are required but missing from metadata: {metadata}")
260
370
 
261
371
  # we need to translate tag keys to tag type ids
262
372
  tags = _get_tags_map(client, metadata)
@@ -267,20 +377,16 @@ def _create_mission(
267
377
  "tags": {str(k): v for k, v in tags.items()},
268
378
  "ignoreTags": ignore_missing_tags,
269
379
  }
270
-
271
380
  resp = client.post(CREATE_MISSION, json=payload)
272
381
  resp.raise_for_status()
382
+ _validate_mission_created(client, str(project_id), mission_name)
273
383
 
274
384
  return UUID(resp.json()["uuid"], version=4)
275
385
 
276
386
 
277
- def _create_project(
278
- client: AuthenticatedClient, project_name: str, description: str
279
- ) -> UUID:
280
- if not _project_name_is_available(client, project_name):
281
- raise ProjectExists(f"Project with name: `{project_name}` already exists")
387
+ def _create_project(client: AuthenticatedClient, project_name: str, description: str) -> UUID:
282
388
 
283
- # TODO: check name and description are valid
389
+ _validate_project_name(client, project_name, description)
284
390
  payload = {"name": project_name, "description": description}
285
391
  resp = client.post(CREATE_PROJECT, json=payload)
286
392
  resp.raise_for_status()
@@ -288,37 +394,60 @@ def _create_project(
288
394
  return UUID(resp.json()["uuid"], version=4)
289
395
 
290
396
 
291
- def _get_metadata_type_id_by_name(
292
- client: AuthenticatedClient, tag_name: str
293
- ) -> Optional[UUID]:
397
+ def _validate_project_name(client: AuthenticatedClient, project_name: str, description: str) -> None:
398
+ if not _project_name_is_available(client, project_name):
399
+ raise ProjectExists(f"Project with name: `{project_name}` already exists")
400
+
401
+ if project_name.endswith(" "):
402
+ raise ProjectValidationError(f"Project name must not end with a tailing whitespace: `{project_name}`")
403
+
404
+ if not description:
405
+ raise ProjectValidationError("Project description is required")
406
+
407
+
408
+ def _validate_tag_value(tag_value, tag_datatype) -> None:
409
+ if tag_datatype == "NUMBER":
410
+ try:
411
+ float(tag_value)
412
+ except ValueError:
413
+ raise InvalidMissionMetadata(f"Value '{tag_value}' is not a valid NUMBER")
414
+ elif tag_datatype == "BOOLEAN":
415
+ if tag_value.lower() not in {"true", "false"}:
416
+ raise InvalidMissionMetadata(f"Value '{tag_value}' is not a valid BOOLEAN (expected 'true' or 'false')")
417
+ else:
418
+ pass # any string is fine
419
+ # TODO: add check for LOCATION tag datatype
420
+
421
+
422
+ def _get_metadata_type_id_by_name(client: AuthenticatedClient, tag_name: str) -> Tuple[Optional[UUID], str]:
294
423
  resp = client.get(TAG_TYPE_BY_NAME, params={"name": tag_name, "take": 1})
295
424
 
296
425
  if resp.status_code in (403, 404):
297
426
  return None
298
427
 
299
428
  resp.raise_for_status()
429
+ try:
430
+ data = resp.json()["data"][0]
431
+ except IndexError:
432
+ return None, None
300
433
 
301
- data = resp.json()[0]
302
- return UUID(data["uuid"], version=4)
434
+ return UUID(data["uuid"], version=4), data["datatype"]
303
435
 
304
436
 
305
- def _get_tags_map(
306
- client: AuthenticatedClient, metadata: Dict[str, str]
307
- ) -> Dict[UUID, str]:
437
+ def _get_tags_map(client: AuthenticatedClient, metadata: Dict[str, str]) -> Dict[UUID, str]:
308
438
  # TODO: this needs a better endpoint
309
439
  # why are we using metadata type ids as keys???
310
440
  ret = {}
311
441
  for key, val in metadata.items():
312
- metadata_type_id = _get_metadata_type_id_by_name(client, key)
442
+ metadata_type_id, tag_datatype = _get_metadata_type_id_by_name(client, key)
313
443
  if metadata_type_id is None:
314
444
  raise InvalidMissionMetadata(f"metadata field: {key} does not exist")
445
+ _validate_tag_value(val, tag_datatype)
315
446
  ret[metadata_type_id] = val
316
447
  return ret
317
448
 
318
449
 
319
- def _update_mission(
320
- client: AuthenticatedClient, mission_id: UUID, *, metadata: Dict[str, str]
321
- ) -> None:
450
+ def _update_mission(client: AuthenticatedClient, mission_id: UUID, *, metadata: Dict[str, str]) -> None:
322
451
  tags_dct = _get_tags_map(client, metadata)
323
452
  payload = {
324
453
  "missionUUID": str(mission_id),
@@ -357,12 +486,18 @@ def _get_api_version() -> Tuple[int, int, int]:
357
486
  config = get_config()
358
487
  client = httpx.Client()
359
488
 
360
- resp = client.get(
361
- f"{config.endpoint.api}{GET_STATUS}", headers={CLI_VERSION_HEADER: __version__}
362
- )
363
- vers = resp.headers["kleinkram-version"].split(".")
489
+ resp = client.get(f"{config.endpoint.api}{GET_STATUS}", headers={CLI_VERSION_HEADER: __version__})
490
+ vers_str = resp.headers.get("kleinkram-version")
364
491
 
365
- return tuple(map(int, vers)) # type: ignore
492
+ if not vers_str:
493
+ return (0, 0, 0)
494
+
495
+ vers = vers_str.split(".")
496
+
497
+ try:
498
+ return tuple(map(int, vers)) # type: ignore
499
+ except ValueError:
500
+ return (0, 0, 0)
366
501
 
367
502
 
368
503
  def _claim_admin(client: AuthenticatedClient) -> None:
@@ -377,9 +512,7 @@ def _claim_admin(client: AuthenticatedClient) -> None:
377
512
  FILE_DELETE_MANY = "/files/deleteMultiple"
378
513
 
379
514
 
380
- def _delete_files(
381
- client: AuthenticatedClient, file_ids: Sequence[UUID], mission_id: UUID
382
- ) -> None:
515
+ def _delete_files(client: AuthenticatedClient, file_ids: Sequence[UUID], mission_id: UUID) -> None:
383
516
  payload = {
384
517
  "uuids": [str(file_id) for file_id in file_ids],
385
518
  "missionUUID": str(mission_id),
kleinkram/auth.py CHANGED
@@ -13,7 +13,7 @@ from kleinkram.config import get_config
13
13
  from kleinkram.config import save_config
14
14
 
15
15
  CLI_CALLBACK_ENDPOINT = "/cli/callback"
16
- OAUTH_SLUG = "/auth/google?state=cli"
16
+ OAUTH_SLUG = "/auth/"
17
17
 
18
18
 
19
19
  def _has_browser() -> bool:
@@ -33,9 +33,7 @@ def _headless_auth(*, url: str) -> None:
33
33
 
34
34
  if auth_token and refresh_token:
35
35
  config = get_config()
36
- config.credentials = Credentials(
37
- auth_token=auth_token, refresh_token=refresh_token
38
- )
36
+ config.credentials = Credentials(auth_token=auth_token, refresh_token=refresh_token)
39
37
  save_config(config)
40
38
  print(f"Authentication complete. Tokens saved to {CONFIG_PATH}.")
41
39
  else:
@@ -80,7 +78,89 @@ def _browser_auth(*, url: str) -> None:
80
78
  print(f"Authentication complete. Tokens saved to {CONFIG_PATH}.")
81
79
 
82
80
 
83
- def login_flow(*, key: Optional[str] = None, headless: bool = False) -> None:
81
+ def _direct_oauth_auth(*, endpoint: str, provider: str, user: str) -> None:
82
+ """
83
+ Directly authenticate with fake OAuth by programmatically following the OAuth flow.
84
+ This bypasses the browser entirely for automated testing.
85
+ """
86
+ import requests
87
+
88
+ print(f"Authenticating as user {user} with {provider}...")
89
+
90
+ try:
91
+ # Step 1: Get the authorization code from fake OAuth
92
+ # The fake OAuth server will auto-redirect when user parameter is provided
93
+ fake_oauth_url = "http://localhost:8004/oauth/authorize"
94
+ callback_url = f"{endpoint}/auth/{provider}/callback"
95
+
96
+ params = {
97
+ "client_id": "some-random-string-it-does-not-matter",
98
+ "redirect_uri": callback_url,
99
+ "response_type": "code",
100
+ "state": "cli-direct",
101
+ "user": user,
102
+ }
103
+
104
+ # Make request to fake OAuth - it will redirect with the auth code
105
+ response = requests.get(fake_oauth_url, params=params, allow_redirects=False)
106
+
107
+ if response.status_code not in [301, 302, 303, 307, 308]:
108
+ raise RuntimeError(f"Expected redirect from OAuth provider, got {response.status_code}")
109
+
110
+ # Extract the redirect location
111
+ location = response.headers.get("Location")
112
+ if not location:
113
+ raise RuntimeError("No redirect location from OAuth provider")
114
+
115
+ # Parse the callback URL to extract the auth code
116
+ parsed = urllib.parse.urlparse(location)
117
+ query_params = urllib.parse.parse_qs(parsed.query)
118
+
119
+ if "code" not in query_params:
120
+ raise RuntimeError(f"No authorization code in redirect: {location}")
121
+
122
+ auth_code = query_params["code"][0]
123
+ state = query_params.get("state", [None])[0]
124
+
125
+ print("Received authorization code, exchanging for tokens...")
126
+
127
+ # Step 2: Exchange the code for tokens by calling the backend callback
128
+ # Use a session to preserve cookies
129
+ session = requests.Session()
130
+ callback_params = {"code": auth_code}
131
+ if state:
132
+ callback_params["state"] = state
133
+
134
+ callback_response = session.get(callback_url, params=callback_params, allow_redirects=False)
135
+
136
+ # The backend should set cookies and redirect
137
+ if callback_response.status_code not in [301, 302, 303, 307, 308]:
138
+ raise RuntimeError(f"Expected redirect from callback, got {callback_response.status_code}")
139
+
140
+ # Extract tokens from cookies
141
+ auth_token = session.cookies.get("authtoken")
142
+ refresh_token = session.cookies.get("refreshtoken")
143
+
144
+ if not auth_token or not refresh_token:
145
+ raise RuntimeError("Failed to get tokens from callback response")
146
+
147
+ # Save tokens
148
+ config = get_config()
149
+ config.credentials = Credentials(auth_token=auth_token, refresh_token=refresh_token)
150
+ save_config(config)
151
+ print(f"Authentication complete. Tokens saved to {CONFIG_PATH}.")
152
+
153
+ except requests.RequestException as e:
154
+ raise RuntimeError(f"OAuth flow failed: {e}")
155
+
156
+
157
+ def login_flow(
158
+ *,
159
+ oAuthProvider: str,
160
+ key: Optional[str] = None,
161
+ headless: bool = False,
162
+ user: Optional[str] = None,
163
+ ) -> None:
84
164
  config = get_config()
85
165
  # use cli key login
86
166
  if key is not None:
@@ -88,8 +168,29 @@ def login_flow(*, key: Optional[str] = None, headless: bool = False) -> None:
88
168
  save_config(config)
89
169
  return
90
170
 
91
- oauth_url = f"{config.endpoint.api}{OAUTH_SLUG}"
92
- if not headless and _has_browser():
171
+ # If user parameter is provided with fake-oauth, use direct OAuth flow
172
+ if user is not None and oAuthProvider == "fake-oauth":
173
+ _direct_oauth_auth(endpoint=config.endpoint.api, provider=oAuthProvider, user=user)
174
+ return
175
+
176
+ # Build OAuth URL with state parameter
177
+ oauth_url = f"{config.endpoint.api}{OAUTH_SLUG}{oAuthProvider}?state=cli"
178
+
179
+ # Add user parameter if provided (for fake-oauth auto-login)
180
+ if user is not None:
181
+ oauth_url += f"&user={user}"
182
+
183
+ is_port_available = True
184
+ try:
185
+ server = HTTPServer(("", 8000), OAuthCallbackHandler)
186
+ server.server_close()
187
+ except OSError:
188
+ is_port_available = False
189
+
190
+ if not is_port_available:
191
+ print("Warning: Port 8000 is not available. Falling back to headless authentication.\n\n")
192
+
193
+ if not headless and _has_browser() and is_port_available:
93
194
  _browser_auth(url=oauth_url)
94
195
  else:
95
196
  _headless_auth(url=f"{oauth_url}-no-redirect")