atdata 0.1.3b4__py3-none-any.whl → 0.2.2b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,533 @@
1
+ """ATProto client wrapper for atdata.
2
+
3
+ This module provides the ``AtmosphereClient`` class which wraps the atproto SDK
4
+ client with atdata-specific helpers for publishing and querying records.
5
+ """
6
+
7
+ from typing import Optional, Any
8
+
9
+ from ._types import AtUri, LEXICON_NAMESPACE
10
+
11
# The atproto SDK is imported on first use so that this module can be imported
# without atproto installed when atmosphere features are not used.
_atproto_client_class: Optional[type] = None


def _get_atproto_client_class():
    """Return the ``atproto.Client`` class, importing it on first call.

    The class is cached in the module-level ``_atproto_client_class``; until an
    import succeeds, every call retries the import.

    Returns:
        The ``atproto.Client`` class.

    Raises:
        ImportError: If the ``atproto`` package is not installed.
    """
    global _atproto_client_class
    if _atproto_client_class is not None:
        return _atproto_client_class

    try:
        from atproto import Client as _Client
    except ImportError as exc:
        raise ImportError(
            "The 'atproto' package is required for ATProto integration. "
            "Install it with: pip install atproto"
        ) from exc

    _atproto_client_class = _Client
    return _atproto_client_class
28
+
29
+
30
class AtmosphereClient:
    """ATProto client wrapper for atdata operations.

    This class wraps the atproto SDK client and provides higher-level methods
    for working with atdata records (schemas, datasets, lenses).

    Example:
        ::

            >>> client = AtmosphereClient()
            >>> client.login("alice.bsky.social", "app-password")
            >>> client.did
            'did:plc:...'

    Note:
        The password should be an app-specific password, not your main account
        password. Create app passwords in your Bluesky account settings.
    """

    # Timeout in seconds for the direct HTTP requests this class makes with
    # ``requests`` (DID document resolution and blob downloads). Without a
    # timeout, ``requests`` can block indefinitely on an unresponsive server.
    # Override on an instance or subclass if a different value is needed.
    http_timeout: float = 30.0

    def __init__(
        self,
        base_url: Optional[str] = None,
        *,
        _client: Optional[Any] = None,
    ):
        """Initialize the ATProto client.

        Args:
            base_url: Optional PDS base URL. If omitted (or empty), the atproto
                ``Client`` default is used (bsky.social).
            _client: Optional pre-configured atproto Client for testing. When
                given, ``base_url`` is ignored and atproto is not imported.
        """
        if _client is not None:
            self._client = _client
        else:
            Client = _get_atproto_client_class()
            self._client = Client(base_url=base_url) if base_url else Client()

        # Holds {"did": ..., "handle": ...} once logged in; None otherwise.
        self._session: Optional[dict[str, str]] = None

    def login(self, handle: str, password: str) -> None:
        """Authenticate with the ATProto PDS.

        Args:
            handle: Your Bluesky handle (e.g., 'alice.bsky.social').
            password: App-specific password (not your main password).

        Raises:
            atproto.exceptions.AtProtocolError: If authentication fails.
        """
        profile = self._client.login(handle, password)
        self._session = {
            "did": profile.did,
            "handle": profile.handle,
        }

    def login_with_session(self, session_string: str) -> None:
        """Authenticate using an exported session string.

        This allows reusing a session without re-authenticating, which helps
        avoid rate limits on session creation.

        Args:
            session_string: Session string from ``export_session()``.
        """
        self._client.login(session_string=session_string)
        self._session = {
            "did": self._client.me.did,
            "handle": self._client.me.handle,
        }

    def export_session(self) -> str:
        """Export the current session for later reuse.

        Returns:
            Session string that can be passed to ``login_with_session()``.

        Raises:
            ValueError: If not authenticated.
        """
        if not self.is_authenticated:
            raise ValueError("Not authenticated")
        return self._client.export_session_string()

    @property
    def is_authenticated(self) -> bool:
        """Whether ``login`` or ``login_with_session`` has completed."""
        return self._session is not None

    @property
    def did(self) -> str:
        """Get the DID of the authenticated user.

        Returns:
            The DID string (e.g., 'did:plc:...').

        Raises:
            ValueError: If not authenticated.
        """
        if not self._session:
            raise ValueError("Not authenticated")
        return self._session["did"]

    @property
    def handle(self) -> str:
        """Get the handle of the authenticated user.

        Returns:
            The handle string (e.g., 'alice.bsky.social').

        Raises:
            ValueError: If not authenticated.
        """
        if not self._session:
            raise ValueError("Not authenticated")
        return self._session["handle"]

    def _ensure_authenticated(self) -> None:
        """Raise ValueError if not authenticated."""
        if not self.is_authenticated:
            raise ValueError("Client must be authenticated to perform this operation")

    @staticmethod
    def _record_value_to_dict(value: Any) -> Any:
        """Convert a record ``value`` from the atproto SDK into a plain dict.

        Conversion is attempted in this order:

        1. a callable ``to_dict()`` (e.g. atproto ``DotDict``);
        2. a ``dict`` (shallow-copied);
        3. a callable ``model_dump()``;
        4. the object's ``__dict__`` (shallow-copied).

        Values matching none of these are returned unchanged.
        """
        if hasattr(value, "to_dict") and callable(value.to_dict):
            return value.to_dict()
        if isinstance(value, dict):
            return dict(value)
        if hasattr(value, "model_dump") and callable(value.model_dump):
            return value.model_dump()
        if hasattr(value, "__dict__"):
            return dict(value.__dict__)
        return value

    # Low-level record operations

    def create_record(
        self,
        collection: str,
        record: dict,
        *,
        rkey: Optional[str] = None,
        validate: bool = False,
    ) -> AtUri:
        """Create a record in the user's repository.

        Args:
            collection: The NSID of the record collection
                (e.g., 'ac.foundation.dataset.sampleSchema').
            record: The record data. Must include a '$type' field.
            rkey: Optional explicit record key. If not provided, a TID is generated.
            validate: Whether to validate against the Lexicon schema. Set to False
                for custom lexicons that the PDS doesn't know about.

        Returns:
            The AT URI of the created record.

        Raises:
            ValueError: If not authenticated.
            atproto.exceptions.AtProtocolError: If record creation fails.
        """
        self._ensure_authenticated()

        response = self._client.com.atproto.repo.create_record(
            data={
                "repo": self.did,
                "collection": collection,
                "record": record,
                "rkey": rkey,
                "validate": validate,
            }
        )

        return AtUri.parse(response.uri)

    def put_record(
        self,
        collection: str,
        rkey: str,
        record: dict,
        *,
        validate: bool = False,
        swap_commit: Optional[str] = None,
    ) -> AtUri:
        """Create or update a record at a specific key.

        Args:
            collection: The NSID of the record collection.
            rkey: The record key.
            record: The record data. Must include a '$type' field.
            validate: Whether to validate against the Lexicon schema.
            swap_commit: Optional CID for compare-and-swap update. Sent as
                ``swapCommit`` only when truthy.

        Returns:
            The AT URI of the record.

        Raises:
            ValueError: If not authenticated.
            atproto.exceptions.AtProtocolError: If operation fails.
        """
        self._ensure_authenticated()

        data: dict[str, Any] = {
            "repo": self.did,
            "collection": collection,
            "rkey": rkey,
            "record": record,
            "validate": validate,
        }
        if swap_commit:
            data["swapCommit"] = swap_commit

        response = self._client.com.atproto.repo.put_record(data=data)

        return AtUri.parse(response.uri)

    def get_record(
        self,
        uri: str | AtUri,
    ) -> dict:
        """Fetch a record by AT URI.

        Does not require authentication.

        Args:
            uri: The AT URI of the record, as a string or ``AtUri``.

        Returns:
            The record data as a dictionary.

        Raises:
            atproto.exceptions.AtProtocolError: If record not found.
        """
        if isinstance(uri, str):
            uri = AtUri.parse(uri)

        response = self._client.com.atproto.repo.get_record(
            params={
                "repo": uri.authority,
                "collection": uri.collection,
                "rkey": uri.rkey,
            }
        )

        return self._record_value_to_dict(response.value)

    def delete_record(
        self,
        uri: str | AtUri,
        *,
        swap_commit: Optional[str] = None,
    ) -> None:
        """Delete a record.

        Args:
            uri: The AT URI of the record to delete, as a string or ``AtUri``.
                Its authority selects the repository.
            swap_commit: Optional CID for compare-and-swap delete. Sent as
                ``swapCommit`` only when truthy.

        Raises:
            ValueError: If not authenticated.
            atproto.exceptions.AtProtocolError: If deletion fails.
        """
        self._ensure_authenticated()

        if isinstance(uri, str):
            uri = AtUri.parse(uri)

        # Target the repository named in the URI. Sending ``self.did`` instead
        # would silently delete the record with the same collection/rkey in
        # the caller's own repo when handed a URI for a different repo; with
        # the URI's authority, the PDS is expected to reject a write to a repo
        # the session does not own.
        data: dict[str, Any] = {
            "repo": uri.authority,
            "collection": uri.collection,
            "rkey": uri.rkey,
        }
        if swap_commit:
            data["swapCommit"] = swap_commit

        self._client.com.atproto.repo.delete_record(data=data)

    def upload_blob(
        self,
        data: bytes,
        mime_type: str = "application/octet-stream",
    ) -> dict:
        """Upload binary data as a blob to the PDS.

        Args:
            data: Binary data to upload.
            mime_type: Accepted for API compatibility but currently unused.
                It is not sent to the PDS. The returned ``mimeType`` is the
                value reported in the upload response.

        Returns:
            A blob reference dict with keys: '$type', 'ref', 'mimeType', 'size'.
            This can be embedded directly in record fields.

        Raises:
            ValueError: If not authenticated.
            atproto.exceptions.AtProtocolError: If upload fails.
        """
        self._ensure_authenticated()

        response = self._client.upload_blob(data)
        blob_ref = response.blob

        # Use the ref's ``link`` attribute when present, else its string form.
        ref = blob_ref.ref
        link = ref.link if hasattr(ref, "link") else str(ref)

        # Convert to dict format suitable for embedding in records
        return {
            "$type": "blob",
            "ref": {"$link": link},
            "mimeType": blob_ref.mime_type,
            "size": blob_ref.size,
        }

    def get_blob(
        self,
        did: str,
        cid: str,
    ) -> bytes:
        """Download a blob from a PDS.

        This resolves the PDS endpoint from the DID document and fetches
        the blob directly from the PDS. Both HTTP requests use
        ``self.http_timeout``.

        Args:
            did: The DID of the repository containing the blob.
            cid: The CID of the blob.

        Returns:
            The blob data as bytes.

        Raises:
            ValueError: If PDS endpoint cannot be resolved.
            requests.HTTPError: If blob fetch fails.
            requests.RequestException: On connection errors or timeouts.
        """
        import requests

        # Resolve PDS endpoint from DID document
        pds_endpoint = self._resolve_pds_endpoint(did)
        if not pds_endpoint:
            raise ValueError(f"Could not resolve PDS endpoint for {did}")

        # Fetch blob from PDS
        response = requests.get(
            f"{pds_endpoint}/xrpc/com.atproto.sync.getBlob",
            params={"did": did, "cid": cid},
            timeout=self.http_timeout,
        )
        response.raise_for_status()
        return response.content

    @staticmethod
    def _did_web_document_url(did: str) -> Optional[str]:
        """Return the DID document URL for a hostname-level ``did:web``.

        A port is percent-encoded as ``%3A`` in the DID. The method returns
        None for path-based identifiers (extra ``:``-separated segments) and
        for hosts containing other URL-significant characters.
        """
        msid = did[len("did:web:"):]
        if not msid or ":" in msid:
            return None
        host = msid.replace("%3A", ":").replace("%3a", ":")
        if any(ch in host for ch in "/?#@%\\"):
            return None
        return f"https://{host}/.well-known/did.json"

    def _resolve_pds_endpoint(self, did: str) -> Optional[str]:
        """Resolve the PDS endpoint for a DID.

        Supported DID methods:

        - ``did:plc``, resolved via https://plc.directory;
        - hostname-level ``did:web``, resolved via
          ``https://<host>/.well-known/did.json``.

        Args:
            did: The DID to resolve.

        Returns:
            The ``serviceEndpoint`` of the first ``AtprotoPersonalDataServer``
            service with a non-empty string endpoint, with trailing slashes
            removed. Returns None if any of the following holds:

            - the DID method is unsupported;
            - the document cannot be fetched or parsed;
            - the document lists no such service.
        """
        import requests

        if did.startswith("did:plc:"):
            doc_url: Optional[str] = f"https://plc.directory/{did}"
        elif did.startswith("did:web:"):
            doc_url = self._did_web_document_url(did)
        else:
            doc_url = None
        if doc_url is None:
            return None

        try:
            response = requests.get(doc_url, timeout=self.http_timeout)
            response.raise_for_status()
            did_doc = response.json()
        except (requests.RequestException, ValueError):
            # ValueError covers JSON decode errors on requests versions whose
            # JSONDecodeError is not a RequestException subclass.
            return None

        if not isinstance(did_doc, dict):
            return None
        services = did_doc.get("service")
        if not isinstance(services, list):
            return None

        for service in services:
            if not isinstance(service, dict):
                continue
            if service.get("type") != "AtprotoPersonalDataServer":
                continue
            endpoint = service.get("serviceEndpoint")
            if isinstance(endpoint, str):
                # Strip trailing slashes so callers can append "/xrpc/...".
                endpoint = endpoint.rstrip("/")
                if endpoint:
                    return endpoint
        return None

    def get_blob_url(self, did: str, cid: str) -> str:
        """Get the direct URL for fetching a blob.

        This is useful for passing to WebDataset or other HTTP clients.

        Args:
            did: The DID of the repository containing the blob.
            cid: The CID of the blob.

        Returns:
            The full URL for fetching the blob.

        Raises:
            ValueError: If PDS endpoint cannot be resolved.
        """
        from urllib.parse import urlencode

        pds_endpoint = self._resolve_pds_endpoint(did)
        if not pds_endpoint:
            raise ValueError(f"Could not resolve PDS endpoint for {did}")
        # Percent-encode query values so characters such as '&' or '#' cannot
        # corrupt the URL. ':' is kept literal (DIDs are colon-separated), so
        # ordinary DIDs/CIDs produce the same URL as plain interpolation.
        query = urlencode({"did": did, "cid": cid}, safe=":")
        return f"{pds_endpoint}/xrpc/com.atproto.sync.getBlob?{query}"

    def list_records(
        self,
        collection: str,
        *,
        repo: Optional[str] = None,
        limit: int = 100,
        cursor: Optional[str] = None,
    ) -> tuple[list[dict], Optional[str]]:
        """List records in a collection.

        Args:
            collection: The NSID of the record collection.
            repo: The DID of the repository to query. Defaults to the
                authenticated user's repository.
            limit: Maximum number of records to return (default 100).
            cursor: Pagination cursor from a previous call.

        Returns:
            A tuple of (records, next_cursor). The cursor is None if there
            are no more records.

        Raises:
            ValueError: If repo is None and not authenticated.
        """
        if repo is None:
            self._ensure_authenticated()
            repo = self.did

        response = self._client.com.atproto.repo.list_records(
            params={
                "repo": repo,
                "collection": collection,
                "limit": limit,
                "cursor": cursor,
            }
        )

        records = [self._record_value_to_dict(r.value) for r in response.records]
        return records, response.cursor

    # Convenience methods for atdata collections

    def list_schemas(
        self,
        repo: Optional[str] = None,
        limit: int = 100,
    ) -> list[dict]:
        """List schema records (first page only; no cursor is followed).

        Args:
            repo: The DID to query. Defaults to authenticated user.
            limit: Maximum number to return.

        Returns:
            List of schema records.
        """
        records, _ = self.list_records(
            f"{LEXICON_NAMESPACE}.sampleSchema",
            repo=repo,
            limit=limit,
        )
        return records

    def list_datasets(
        self,
        repo: Optional[str] = None,
        limit: int = 100,
    ) -> list[dict]:
        """List dataset records (first page only; no cursor is followed).

        Args:
            repo: The DID to query. Defaults to authenticated user.
            limit: Maximum number to return.

        Returns:
            List of dataset records.
        """
        records, _ = self.list_records(
            f"{LEXICON_NAMESPACE}.record",
            repo=repo,
            limit=limit,
        )
        return records

    def list_lenses(
        self,
        repo: Optional[str] = None,
        limit: int = 100,
    ) -> list[dict]:
        """List lens records (first page only; no cursor is followed).

        Args:
            repo: The DID to query. Defaults to authenticated user.
            limit: Maximum number to return.

        Returns:
            List of lens records.
        """
        records, _ = self.list_records(
            f"{LEXICON_NAMESPACE}.lens",
            repo=repo,
            limit=limit,
        )
        return records