wandb-0.18.0-py3-none-win32.whl → wandb-0.18.1-py3-none-win32.whl

This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (62)
  1. wandb/__init__.py +2 -2
  2. wandb/__init__.pyi +1 -1
  3. wandb/apis/public/runs.py +2 -0
  4. wandb/bin/wandb-core +0 -0
  5. wandb/cli/cli.py +0 -2
  6. wandb/data_types.py +9 -2019
  7. wandb/env.py +0 -5
  8. wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
  9. wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
  10. wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
  11. wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
  12. wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
  13. wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
  14. wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
  15. wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
  16. wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
  17. wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
  18. wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
  19. wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
  20. wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
  21. wandb/{sklearn → integration/sklearn}/utils.py +8 -8
  22. wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
  23. wandb/proto/v3/wandb_base_pb2.py +2 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +2 -1
  25. wandb/proto/v3/wandb_server_pb2.py +2 -1
  26. wandb/proto/v3/wandb_settings_pb2.py +2 -1
  27. wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
  28. wandb/proto/v4/wandb_base_pb2.py +2 -1
  29. wandb/proto/v4/wandb_internal_pb2.py +2 -1
  30. wandb/proto/v4/wandb_server_pb2.py +2 -1
  31. wandb/proto/v4/wandb_settings_pb2.py +2 -1
  32. wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
  33. wandb/proto/v5/wandb_base_pb2.py +3 -2
  34. wandb/proto/v5/wandb_internal_pb2.py +3 -2
  35. wandb/proto/v5/wandb_server_pb2.py +3 -2
  36. wandb/proto/v5/wandb_settings_pb2.py +3 -2
  37. wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
  38. wandb/sdk/data_types/audio.py +165 -0
  39. wandb/sdk/data_types/bokeh.py +70 -0
  40. wandb/sdk/data_types/graph.py +405 -0
  41. wandb/sdk/data_types/image.py +156 -0
  42. wandb/sdk/data_types/table.py +1204 -0
  43. wandb/sdk/data_types/trace_tree.py +2 -2
  44. wandb/sdk/data_types/utils.py +49 -0
  45. wandb/sdk/service/service.py +2 -9
  46. wandb/sdk/service/streams.py +0 -7
  47. wandb/sdk/wandb_init.py +10 -3
  48. wandb/sdk/wandb_run.py +6 -152
  49. wandb/sdk/wandb_setup.py +1 -1
  50. wandb/sklearn.py +35 -0
  51. wandb/util.py +6 -2
  52. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/METADATA +1 -1
  53. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/RECORD +61 -57
  54. wandb/sdk/lib/console.py +0 -39
  55. /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
  56. /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
  57. /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
  58. /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
  59. /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
  60. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/WHEEL +0 -0
  61. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/entry_points.txt +0 -0
  62. {wandb-0.18.0.dist-info → wandb-0.18.1.dist-info}/licenses/LICENSE +0 -0
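The file list above is dominated by two refactors: the `wandb/sklearn` package moves under `wandb/integration/sklearn` (with a new 35-line `wandb/sklearn.py` left in place, presumably as a compatibility shim), and most of `wandb/data_types.py` is carved out into modules under `wandb/sdk/data_types/` (`audio.py`, `bokeh.py`, `graph.py`, `table.py`), which `data_types.py` now re-imports. The sketch below is a minimal, unverified check of that assumption; it only uses names that appear in this diff, and the sklearn lines additionally assume scikit-learn is installed.

```
# Minimal sketch (not from the package docs): check that the 0.18.1 reorganization
# keeps the old public import paths working. Assumes wandb 0.18.1 is installed and,
# for the sklearn imports, that scikit-learn is available.
import wandb.data_types as data_types
from wandb.sdk.data_types.table import Table as SdkTable  # new canonical location per this diff

# data_types.py now re-exports Table/Audio/Graph/Node from wandb.sdk.data_types.*,
# so the legacy and new names should refer to the same class objects.
assert data_types.Table is SdkTable

# The legacy wandb.sklearn module should still import, forwarding to the relocated package.
import wandb.sklearn as legacy_sklearn            # old path, now a thin shim (+35 lines in this diff)
import wandb.integration.sklearn as new_sklearn   # new home of the calculate/plot modules

print(legacy_sklearn.__file__)
print(new_sklearn.__file__)
```

If the assertion holds, downstream code that imports from `wandb.data_types` or calls into `wandb.sklearn` should not need changes for 0.18.1; nothing here is confirmed beyond what the file list and the data_types.py diff imply.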
wandb/data_types.py CHANGED
@@ -13,30 +13,10 @@ serialize to JSON, since that is what wandb uses to save the objects locally
13
13
  and upload them to the W&B server.
14
14
  """
15
15
 
16
- import base64
17
- import binascii
18
- import codecs
19
- import datetime
20
- import hashlib
21
- import json
22
- import logging
23
- import os
24
- import pprint
25
- from decimal import Decimal
26
- from typing import Optional
27
-
28
- import wandb
29
- from wandb import util
30
- from wandb.sdk.lib import filesystem
31
-
32
- from .sdk.data_types import _dtypes
33
- from .sdk.data_types._private import MEDIA_TMP
34
- from .sdk.data_types.base_types.media import (
35
- BatchableMedia,
36
- Media,
37
- _numpy_arrays_to_lists,
38
- )
16
+ from .sdk.data_types.audio import Audio
39
17
  from .sdk.data_types.base_types.wb_value import WBValue
18
+ from .sdk.data_types.bokeh import Bokeh
19
+ from .sdk.data_types.graph import Graph, Node
40
20
  from .sdk.data_types.helper_types.bounding_boxes_2d import BoundingBoxes2D
41
21
  from .sdk.data_types.helper_types.classes import Classes
42
22
  from .sdk.data_types.helper_types.image_mask import ImageMask
@@ -47,9 +27,9 @@ from .sdk.data_types.molecule import Molecule
47
27
  from .sdk.data_types.object_3d import Object3D, box3d
48
28
  from .sdk.data_types.plotly import Plotly
49
29
  from .sdk.data_types.saved_model import _SavedModel
30
+ from .sdk.data_types.table import JoinedTable, PartitionedTable, Table
50
31
  from .sdk.data_types.trace_tree import WBTraceTree
51
32
  from .sdk.data_types.video import Video
52
- from .sdk.lib import runid
53
33
 
54
34
  # Note: we are importing everything from the sdk/data_types to maintain a namespace for now.
55
35
  # Once we fully type this file and move it all into sdk, then we will need to clean up the
@@ -59,7 +39,11 @@ __all__ = [
59
39
  # Untyped Exports
60
40
  "Audio",
61
41
  "Table",
42
+ "JoinedTable",
43
+ "PartitionedTable",
62
44
  "Bokeh",
45
+ "Node",
46
+ "Graph",
63
47
  # Typed Exports
64
48
  "Histogram",
65
49
  "Html",
@@ -71,2003 +55,9 @@ __all__ = [
71
55
  "Video",
72
56
  "WBTraceTree",
73
57
  "_SavedModel",
58
+ "WBValue",
74
59
  # Typed Legacy Exports (I'd like to remove these)
75
60
  "ImageMask",
76
61
  "BoundingBoxes2D",
77
62
  "Classes",
78
63
  ]
79
-
80
-
81
- class _TableLinkMixin:
82
- def set_table(self, table):
83
- self._table = table
84
-
85
-
86
- class _TableKey(str, _TableLinkMixin):
87
- def set_table(self, table, col_name):
88
- assert col_name in table.columns
89
- self._table = table
90
- self._col_name = col_name
91
-
92
-
93
- class _TableIndex(int, _TableLinkMixin):
94
- def get_row(self):
95
- row = {}
96
- if self._table:
97
- row = {
98
- c: self._table.data[self][i] for i, c in enumerate(self._table.columns)
99
- }
100
-
101
- return row
102
-
103
-
104
- def _json_helper(val, artifact):
105
- if isinstance(val, WBValue):
106
- return val.to_json(artifact)
107
- elif val.__class__ is dict:
108
- res = {}
109
- for key in val:
110
- res[key] = _json_helper(val[key], artifact)
111
- return res
112
-
113
- if hasattr(val, "tolist"):
114
- py_val = val.tolist()
115
- if val.__class__.__name__ == "datetime64" and isinstance(py_val, int):
116
- # when numpy datetime64 .tolist() returns an int, it is nanoseconds.
117
- # need to convert to milliseconds
118
- return _json_helper(py_val / int(1e6), artifact)
119
- return _json_helper(py_val, artifact)
120
- elif hasattr(val, "item"):
121
- return _json_helper(val.item(), artifact)
122
-
123
- if isinstance(val, datetime.datetime):
124
- if val.tzinfo is None:
125
- val = datetime.datetime(
126
- val.year,
127
- val.month,
128
- val.day,
129
- val.hour,
130
- val.minute,
131
- val.second,
132
- val.microsecond,
133
- tzinfo=datetime.timezone.utc,
134
- )
135
- return int(val.timestamp() * 1000)
136
- elif isinstance(val, datetime.date):
137
- return int(
138
- datetime.datetime(
139
- val.year, val.month, val.day, tzinfo=datetime.timezone.utc
140
- ).timestamp()
141
- * 1000
142
- )
143
- elif isinstance(val, (list, tuple)):
144
- return [_json_helper(i, artifact) for i in val]
145
- elif isinstance(val, Decimal):
146
- return float(val)
147
- else:
148
- return util.json_friendly(val)[0]
149
-
150
-
151
- class Table(Media):
152
- """The Table class used to display and analyze tabular data.
153
-
154
- Unlike traditional spreadsheets, Tables support numerous types of data:
155
- scalar values, strings, numpy arrays, and most subclasses of `wandb.data_types.Media`.
156
- This means you can embed `Images`, `Video`, `Audio`, and other sorts of rich, annotated media
157
- directly in Tables, alongside other traditional scalar values.
158
-
159
- This class is the primary class used to generate the Table Visualizer
160
- in the UI: https://docs.wandb.ai/guides/data-vis/tables.
161
-
162
- Arguments:
163
- columns: (List[str]) Names of the columns in the table.
164
- Defaults to ["Input", "Output", "Expected"].
165
- data: (List[List[any]]) 2D row-oriented array of values.
166
- dataframe: (pandas.DataFrame) DataFrame object used to create the table.
167
- When set, `data` and `columns` arguments are ignored.
168
- optional: (Union[bool,List[bool]]) Determines if `None` values are allowed. Default to True
169
- - If a singular bool value, then the optionality is enforced for all
170
- columns specified at construction time
171
- - If a list of bool values, then the optionality is applied to each
172
- column - should be the same length as `columns`
173
- applies to all columns. A list of bool values applies to each respective column.
174
- allow_mixed_types: (bool) Determines if columns are allowed to have mixed types
175
- (disables type validation). Defaults to False
176
- """
177
-
178
- MAX_ROWS = 10000
179
- MAX_ARTIFACT_ROWS = 200000
180
- _MAX_EMBEDDING_DIMENSIONS = 150
181
- _log_type = "table"
182
-
183
- def __init__(
184
- self,
185
- columns=None,
186
- data=None,
187
- rows=None,
188
- dataframe=None,
189
- dtype=None,
190
- optional=True,
191
- allow_mixed_types=False,
192
- ):
193
- """Initializes a Table object.
194
-
195
- The rows is available for legacy reasons and should not be used.
196
- The Table class uses data to mimic the Pandas API.
197
- """
198
- super().__init__()
199
- self._pk_col = None
200
- self._fk_cols = set()
201
- if allow_mixed_types:
202
- dtype = _dtypes.AnyType
203
-
204
- # This is kept for legacy reasons (tss: personally, I think we should remove this)
205
- if columns is None:
206
- columns = ["Input", "Output", "Expected"]
207
-
208
- # Explicit dataframe option
209
- if dataframe is not None:
210
- self._init_from_dataframe(dataframe, columns, optional, dtype)
211
- else:
212
- # Expected pattern
213
- if data is not None:
214
- if util.is_numpy_array(data):
215
- self._init_from_ndarray(data, columns, optional, dtype)
216
- elif util.is_pandas_data_frame(data):
217
- self._init_from_dataframe(data, columns, optional, dtype)
218
- else:
219
- self._init_from_list(data, columns, optional, dtype)
220
-
221
- # legacy
222
- elif rows is not None:
223
- self._init_from_list(rows, columns, optional, dtype)
224
-
225
- # Default empty case
226
- else:
227
- self._init_from_list([], columns, optional, dtype)
228
-
229
- @staticmethod
230
- def _assert_valid_columns(columns):
231
- valid_col_types = [str, int]
232
- assert isinstance(columns, list), "columns argument expects a `list` object"
233
- assert len(columns) == 0 or all(
234
- [type(col) in valid_col_types for col in columns]
235
- ), "columns argument expects list of strings or ints"
236
-
237
- def _init_from_list(self, data, columns, optional=True, dtype=None):
238
- assert isinstance(data, list), "data argument expects a `list` object"
239
- self.data = []
240
- self._assert_valid_columns(columns)
241
- self.columns = columns
242
- self._make_column_types(dtype, optional)
243
- for row in data:
244
- self.add_data(*row)
245
-
246
- def _init_from_ndarray(self, ndarray, columns, optional=True, dtype=None):
247
- assert util.is_numpy_array(
248
- ndarray
249
- ), "ndarray argument expects a `numpy.ndarray` object"
250
- self.data = []
251
- self._assert_valid_columns(columns)
252
- self.columns = columns
253
- self._make_column_types(dtype, optional)
254
- for row in ndarray:
255
- self.add_data(*row)
256
-
257
- def _init_from_dataframe(self, dataframe, columns, optional=True, dtype=None):
258
- assert util.is_pandas_data_frame(
259
- dataframe
260
- ), "dataframe argument expects a `pandas.core.frame.DataFrame` object"
261
- self.data = []
262
- columns = list(dataframe.columns)
263
- self._assert_valid_columns(columns)
264
- self.columns = columns
265
- self._make_column_types(dtype, optional)
266
- for row in range(len(dataframe)):
267
- self.add_data(*tuple(dataframe[col].values[row] for col in self.columns))
268
-
269
- def _make_column_types(self, dtype=None, optional=True):
270
- if dtype is None:
271
- dtype = _dtypes.UnknownType()
272
-
273
- if optional.__class__ is not list:
274
- optional = [optional for _ in range(len(self.columns))]
275
-
276
- if dtype.__class__ is not list:
277
- dtype = [dtype for _ in range(len(self.columns))]
278
-
279
- self._column_types = _dtypes.TypedDictType({})
280
- for col_name, opt, dt in zip(self.columns, optional, dtype):
281
- self.cast(col_name, dt, opt)
282
-
283
- def cast(self, col_name, dtype, optional=False):
284
- """Casts a column to a specific data type.
285
-
286
- This can be one of the normal python classes, an internal W&B type, or an
287
- example object, like an instance of wandb.Image or wandb.Classes.
288
-
289
- Arguments:
290
- col_name: (str) - The name of the column to cast.
291
- dtype: (class, wandb.wandb_sdk.interface._dtypes.Type, any) - The target dtype.
292
- optional: (bool) - If the column should allow Nones.
293
- """
294
- assert col_name in self.columns
295
-
296
- wbtype = _dtypes.TypeRegistry.type_from_dtype(dtype)
297
-
298
- if optional:
299
- wbtype = _dtypes.OptionalType(wbtype)
300
-
301
- # Cast each value in the row, raising an error if there are invalid entries.
302
- col_ndx = self.columns.index(col_name)
303
- for row in self.data:
304
- result_type = wbtype.assign(row[col_ndx])
305
- if isinstance(result_type, _dtypes.InvalidType):
306
- raise TypeError(
307
- "Existing data {}, of type {} cannot be cast to {}".format(
308
- row[col_ndx],
309
- _dtypes.TypeRegistry.type_of(row[col_ndx]),
310
- wbtype,
311
- )
312
- )
313
- wbtype = result_type
314
-
315
- # Assert valid options
316
- is_pk = isinstance(wbtype, _PrimaryKeyType)
317
- is_fk = isinstance(wbtype, _ForeignKeyType)
318
- is_fi = isinstance(wbtype, _ForeignIndexType)
319
- if is_pk or is_fk or is_fi:
320
- assert (
321
- not optional
322
- ), "Primary keys, foreign keys, and foreign indexes cannot be optional."
323
-
324
- if (is_fk or is_fk) and id(wbtype.params["table"]) == id(self):
325
- raise AssertionError("Cannot set a foreign table reference to same table.")
326
-
327
- if is_pk:
328
- assert (
329
- self._pk_col is None
330
- ), "Cannot have multiple primary keys - {} is already set as the primary key.".format(
331
- self._pk_col
332
- )
333
-
334
- # Update the column type
335
- self._column_types.params["type_map"][col_name] = wbtype
336
-
337
- # Wrap the data if needed
338
- self._update_keys()
339
- return wbtype
340
-
341
- def __ne__(self, other):
342
- return not self.__eq__(other)
343
-
344
- def _eq_debug(self, other, should_assert=False):
345
- eq = isinstance(other, Table)
346
- assert not should_assert or eq, "Found type {}, expected {}".format(
347
- other.__class__, Table
348
- )
349
- eq = eq and len(self.data) == len(other.data)
350
- assert not should_assert or eq, "Found {} rows, expected {}".format(
351
- len(other.data), len(self.data)
352
- )
353
- eq = eq and self.columns == other.columns
354
- assert not should_assert or eq, "Found columns {}, expected {}".format(
355
- other.columns, self.columns
356
- )
357
- eq = eq and self._column_types == other._column_types
358
- assert (
359
- not should_assert or eq
360
- ), "Found column type {}, expected column type {}".format(
361
- other._column_types, self._column_types
362
- )
363
- if eq:
364
- for row_ndx in range(len(self.data)):
365
- for col_ndx in range(len(self.data[row_ndx])):
366
- _eq = self.data[row_ndx][col_ndx] == other.data[row_ndx][col_ndx]
367
- # equal if all are equal
368
- if util.is_numpy_array(_eq):
369
- _eq = ((_eq * -1) + 1).sum() == 0
370
- eq = eq and _eq
371
- assert (
372
- not should_assert or eq
373
- ), "Unequal data at row_ndx {} col_ndx {}: found {}, expected {}".format(
374
- row_ndx,
375
- col_ndx,
376
- other.data[row_ndx][col_ndx],
377
- self.data[row_ndx][col_ndx],
378
- )
379
- if not eq:
380
- return eq
381
- return eq
382
-
383
- def __eq__(self, other):
384
- return self._eq_debug(other)
385
-
386
- def add_row(self, *row):
387
- """Deprecated; use add_data instead."""
388
- logging.warning("add_row is deprecated, use add_data")
389
- self.add_data(*row)
390
-
391
- def add_data(self, *data):
392
- """Adds a new row of data to the table. The maximum amount of rows in a table is determined by `wandb.Table.MAX_ARTIFACT_ROWS`.
393
-
394
- The length of the data should match the length of the table column.
395
- """
396
- if len(data) != len(self.columns):
397
- raise ValueError(
398
- "This table expects {} columns: {}, found {}".format(
399
- len(self.columns), self.columns, len(data)
400
- )
401
- )
402
-
403
- # Special case to pre-emptively cast a column as a key.
404
- # Needed as String.assign(Key) is invalid
405
- for ndx, item in enumerate(data):
406
- if isinstance(item, _TableLinkMixin):
407
- self.cast(
408
- self.columns[ndx],
409
- _dtypes.TypeRegistry.type_of(item),
410
- optional=False,
411
- )
412
-
413
- # Update the table's column types
414
- result_type = self._get_updated_result_type(data)
415
- self._column_types = result_type
416
-
417
- # rows need to be mutable
418
- if isinstance(data, tuple):
419
- data = list(data)
420
- # Add the new data
421
- self.data.append(data)
422
-
423
- # Update the wrapper values if needed
424
- self._update_keys(force_last=True)
425
-
426
- def _get_updated_result_type(self, row):
427
- """Returns the updated result type based on the inputted row.
428
-
429
- Raises:
430
- TypeError: if the assignment is invalid.
431
- """
432
- incoming_row_dict = {
433
- col_key: row[ndx] for ndx, col_key in enumerate(self.columns)
434
- }
435
- current_type = self._column_types
436
- result_type = current_type.assign(incoming_row_dict)
437
- if isinstance(result_type, _dtypes.InvalidType):
438
- raise TypeError(
439
- "Data row contained incompatible types:\n{}".format(
440
- current_type.explain(incoming_row_dict)
441
- )
442
- )
443
- return result_type
444
-
445
- def _to_table_json(self, max_rows=None, warn=True):
446
- # separate this method for easier testing
447
- if max_rows is None:
448
- max_rows = Table.MAX_ROWS
449
- n_rows = len(self.data)
450
- if n_rows > max_rows and warn:
451
- if wandb.run and (
452
- wandb.run.settings.table_raise_on_max_row_limit_exceeded
453
- or wandb.run.settings.strict
454
- ):
455
- raise ValueError(
456
- f"Table row limit exceeded: table has {n_rows} rows, limit is {max_rows}. "
457
- f"To increase the maximum number of allowed rows in a wandb.Table, override "
458
- f"the limit with `wandb.Table.MAX_ARTIFACT_ROWS = X` and try again. Note: "
459
- f"this may cause slower queries in the W&B UI."
460
- )
461
- logging.warning("Truncating wandb.Table object to %i rows." % max_rows)
462
- return {"columns": self.columns, "data": self.data[:max_rows]}
463
-
464
- def bind_to_run(self, *args, **kwargs):
465
- # We set `warn=False` since Tables will now always be logged to both
466
- # files and artifacts. The file limit will never practically matter and
467
- # this code path will be ultimately removed. The 10k limit warning confuses
468
- # users given that we publicly say 200k is the limit.
469
- data = self._to_table_json(warn=False)
470
- tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".table.json")
471
- data = _numpy_arrays_to_lists(data)
472
- with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
473
- util.json_dump_safer(data, fp)
474
- self._set_file(tmp_path, is_tmp=True, extension=".table.json")
475
- super().bind_to_run(*args, **kwargs)
476
-
477
- @classmethod
478
- def get_media_subdir(cls):
479
- return os.path.join("media", "table")
480
-
481
- @classmethod
482
- def from_json(cls, json_obj, source_artifact):
483
- data = []
484
- column_types = None
485
- np_deserialized_columns = {}
486
- timestamp_column_indices = set()
487
- if json_obj.get("column_types") is not None:
488
- column_types = _dtypes.TypeRegistry.type_from_dict(
489
- json_obj["column_types"], source_artifact
490
- )
491
- for col_name in column_types.params["type_map"]:
492
- col_type = column_types.params["type_map"][col_name]
493
- ndarray_type = None
494
- if isinstance(col_type, _dtypes.NDArrayType):
495
- ndarray_type = col_type
496
- elif isinstance(col_type, _dtypes.UnionType):
497
- for t in col_type.params["allowed_types"]:
498
- if isinstance(t, _dtypes.NDArrayType):
499
- ndarray_type = t
500
- elif isinstance(t, _dtypes.TimestampType):
501
- timestamp_column_indices.add(
502
- json_obj["columns"].index(col_name)
503
- )
504
-
505
- elif isinstance(col_type, _dtypes.TimestampType):
506
- timestamp_column_indices.add(json_obj["columns"].index(col_name))
507
-
508
- if (
509
- ndarray_type is not None
510
- and ndarray_type._get_serialization_path() is not None
511
- ):
512
- serialization_path = ndarray_type._get_serialization_path()
513
- np = util.get_module(
514
- "numpy",
515
- required="Deserializing NumPy columns requires NumPy to be installed.",
516
- )
517
- deserialized = np.load(
518
- source_artifact.get_entry(serialization_path["path"]).download()
519
- )
520
- np_deserialized_columns[json_obj["columns"].index(col_name)] = (
521
- deserialized[serialization_path["key"]]
522
- )
523
- ndarray_type._clear_serialization_path()
524
-
525
- for r_ndx, row in enumerate(json_obj["data"]):
526
- row_data = []
527
- for c_ndx, item in enumerate(row):
528
- cell = item
529
- if c_ndx in timestamp_column_indices and isinstance(item, (int, float)):
530
- cell = datetime.datetime.fromtimestamp(
531
- item / 1000, tz=datetime.timezone.utc
532
- )
533
- elif c_ndx in np_deserialized_columns:
534
- cell = np_deserialized_columns[c_ndx][r_ndx]
535
- elif isinstance(item, dict) and "_type" in item:
536
- obj = WBValue.init_from_json(item, source_artifact)
537
- if obj is not None:
538
- cell = obj
539
- row_data.append(cell)
540
- data.append(row_data)
541
-
542
- # construct Table with dtypes for each column if type information exists
543
- dtypes = None
544
- if column_types is not None:
545
- dtypes = [
546
- column_types.params["type_map"][str(col)] for col in json_obj["columns"]
547
- ]
548
-
549
- new_obj = cls(columns=json_obj["columns"], data=data, dtype=dtypes)
550
-
551
- if column_types is not None:
552
- new_obj._column_types = column_types
553
-
554
- new_obj._update_keys()
555
- return new_obj
556
-
557
- def to_json(self, run_or_artifact):
558
- json_dict = super().to_json(run_or_artifact)
559
-
560
- if isinstance(run_or_artifact, wandb.wandb_sdk.wandb_run.Run):
561
- json_dict.update(
562
- {
563
- "_type": "table-file",
564
- "ncols": len(self.columns),
565
- "nrows": len(self.data),
566
- }
567
- )
568
-
569
- elif isinstance(run_or_artifact, wandb.Artifact):
570
- artifact = run_or_artifact
571
- mapped_data = []
572
- data = self._to_table_json(Table.MAX_ARTIFACT_ROWS)["data"]
573
-
574
- ndarray_col_ndxs = set()
575
- for col_ndx, col_name in enumerate(self.columns):
576
- col_type = self._column_types.params["type_map"][col_name]
577
- ndarray_type = None
578
- if isinstance(col_type, _dtypes.NDArrayType):
579
- ndarray_type = col_type
580
- elif isinstance(col_type, _dtypes.UnionType):
581
- for t in col_type.params["allowed_types"]:
582
- if isinstance(t, _dtypes.NDArrayType):
583
- ndarray_type = t
584
-
585
- # Do not serialize 1d arrays - these are likely embeddings and
586
- # will not have the same cost as higher dimensional arrays
587
- is_1d_array = (
588
- ndarray_type is not None
589
- and "shape" in ndarray_type._params
590
- and isinstance(ndarray_type._params["shape"], list)
591
- and len(ndarray_type._params["shape"]) == 1
592
- and ndarray_type._params["shape"][0]
593
- <= self._MAX_EMBEDDING_DIMENSIONS
594
- )
595
- if is_1d_array:
596
- self._column_types.params["type_map"][col_name] = _dtypes.ListType(
597
- _dtypes.NumberType, ndarray_type._params["shape"][0]
598
- )
599
- elif ndarray_type is not None:
600
- np = util.get_module(
601
- "numpy",
602
- required="Serializing NumPy requires NumPy to be installed.",
603
- )
604
- file_name = f"{str(col_name)}_{runid.generate_id()}.npz"
605
- npz_file_name = os.path.join(MEDIA_TMP.name, file_name)
606
- np.savez_compressed(
607
- npz_file_name,
608
- **{
609
- str(col_name): self.get_column(col_name, convert_to="numpy")
610
- },
611
- )
612
- entry = artifact.add_file(
613
- npz_file_name, "media/serialized_data/" + file_name, is_tmp=True
614
- )
615
- ndarray_type._set_serialization_path(entry.path, str(col_name))
616
- ndarray_col_ndxs.add(col_ndx)
617
-
618
- for row in data:
619
- mapped_row = []
620
- for ndx, v in enumerate(row):
621
- if ndx in ndarray_col_ndxs:
622
- mapped_row.append(None)
623
- else:
624
- mapped_row.append(_json_helper(v, artifact))
625
- mapped_data.append(mapped_row)
626
-
627
- json_dict.update(
628
- {
629
- "_type": Table._log_type,
630
- "columns": self.columns,
631
- "data": mapped_data,
632
- "ncols": len(self.columns),
633
- "nrows": len(mapped_data),
634
- "column_types": self._column_types.to_json(artifact),
635
- }
636
- )
637
- else:
638
- raise ValueError("to_json accepts wandb_run.Run or wandb_artifact.Artifact")
639
-
640
- return json_dict
641
-
642
- def iterrows(self):
643
- """Returns the table data by row, showing the index of the row and the relevant data.
644
-
645
- Yields:
646
- ------
647
- index : int
648
- The index of the row. Using this value in other W&B tables
649
- will automatically build a relationship between the tables
650
- row : List[any]
651
- The data of the row.
652
- """
653
- for ndx in range(len(self.data)):
654
- index = _TableIndex(ndx)
655
- index.set_table(self)
656
- yield index, self.data[ndx]
657
-
658
- def set_pk(self, col_name):
659
- # TODO: Docs
660
- assert col_name in self.columns
661
- self.cast(col_name, _PrimaryKeyType())
662
-
663
- def set_fk(self, col_name, table, table_col):
664
- # TODO: Docs
665
- assert col_name in self.columns
666
- assert col_name != self._pk_col
667
- self.cast(col_name, _ForeignKeyType(table, table_col))
668
-
669
- def _update_keys(self, force_last=False):
670
- """Updates the known key-like columns based on current column types.
671
-
672
- If the state has been updated since the last update, wraps the data
673
- appropriately in the Key classes.
674
-
675
- Arguments:
676
- force_last: (bool) Wraps the last column of data even if there
677
- are no key updates.
678
- """
679
- _pk_col = None
680
- _fk_cols = set()
681
-
682
- # Buildup the known keys from column types
683
- c_types = self._column_types.params["type_map"]
684
- for t in c_types:
685
- if isinstance(c_types[t], _PrimaryKeyType):
686
- _pk_col = t
687
- elif isinstance(c_types[t], _ForeignKeyType) or isinstance(
688
- c_types[t], _ForeignIndexType
689
- ):
690
- _fk_cols.add(t)
691
-
692
- # If there are updates to perform, safely update them
693
- has_update = _pk_col != self._pk_col or _fk_cols != self._fk_cols
694
- if has_update:
695
- # If we removed the PK
696
- if _pk_col is None and self._pk_col is not None:
697
- raise AssertionError(
698
- f"Cannot unset primary key (column {self._pk_col})"
699
- )
700
- # If there is a removed FK
701
- if len(self._fk_cols - _fk_cols) > 0:
702
- raise AssertionError(
703
- "Cannot unset foreign key. Attempted to unset ({})".format(
704
- self._fk_cols - _fk_cols
705
- )
706
- )
707
-
708
- self._pk_col = _pk_col
709
- self._fk_cols = _fk_cols
710
-
711
- # Apply updates to data only if there are update or the caller
712
- # requested the final row to be updated
713
- if has_update or force_last:
714
- self._apply_key_updates(not has_update)
715
-
716
- def _apply_key_updates(self, only_last=False):
717
- """Appropriately wraps the underlying data in special Key classes.
718
-
719
- Arguments:
720
- only_last: only apply the updates to the last row (used for performance when
721
- the caller knows that the only new data is the last row and no updates were
722
- applied to the column types)
723
- """
724
- c_types = self._column_types.params["type_map"]
725
-
726
- # Define a helper function which will wrap the data of a single row
727
- # in the appropriate class wrapper.
728
- def update_row(row_ndx):
729
- for fk_col in self._fk_cols:
730
- col_ndx = self.columns.index(fk_col)
731
-
732
- # Wrap the Foreign Keys
733
- if isinstance(c_types[fk_col], _ForeignKeyType) and not isinstance(
734
- self.data[row_ndx][col_ndx], _TableKey
735
- ):
736
- self.data[row_ndx][col_ndx] = _TableKey(self.data[row_ndx][col_ndx])
737
- self.data[row_ndx][col_ndx].set_table(
738
- c_types[fk_col].params["table"],
739
- c_types[fk_col].params["col_name"],
740
- )
741
-
742
- # Wrap the Foreign Indexes
743
- elif isinstance(c_types[fk_col], _ForeignIndexType) and not isinstance(
744
- self.data[row_ndx][col_ndx], _TableIndex
745
- ):
746
- self.data[row_ndx][col_ndx] = _TableIndex(
747
- self.data[row_ndx][col_ndx]
748
- )
749
- self.data[row_ndx][col_ndx].set_table(
750
- c_types[fk_col].params["table"]
751
- )
752
-
753
- # Wrap the Primary Key
754
- if self._pk_col is not None:
755
- col_ndx = self.columns.index(self._pk_col)
756
- self.data[row_ndx][col_ndx] = _TableKey(self.data[row_ndx][col_ndx])
757
- self.data[row_ndx][col_ndx].set_table(self, self._pk_col)
758
-
759
- if only_last:
760
- update_row(len(self.data) - 1)
761
- else:
762
- for row_ndx in range(len(self.data)):
763
- update_row(row_ndx)
764
-
765
- def add_column(self, name, data, optional=False):
766
- """Adds a column of data to the table.
767
-
768
- Arguments:
769
- name: (str) - the unique name of the column
770
- data: (list | np.array) - a column of homogeneous data
771
- optional: (bool) - if null-like values are permitted
772
- """
773
- assert isinstance(name, str) and name not in self.columns
774
- is_np = util.is_numpy_array(data)
775
- assert isinstance(data, list) or is_np
776
- assert isinstance(optional, bool)
777
- is_first_col = len(self.columns) == 0
778
- assert is_first_col or len(data) == len(
779
- self.data
780
- ), f"Expected length {len(self.data)}, found {len(data)}"
781
-
782
- # Add the new data
783
- for ndx in range(max(len(data), len(self.data))):
784
- if is_first_col:
785
- self.data.append([])
786
- if is_np:
787
- self.data[ndx].append(data[ndx])
788
- else:
789
- self.data[ndx].append(data[ndx])
790
- # add the column
791
- self.columns.append(name)
792
-
793
- try:
794
- self.cast(name, _dtypes.UnknownType(), optional=optional)
795
- except TypeError as err:
796
- # Undo the changes
797
- if is_first_col:
798
- self.data = []
799
- self.columns = []
800
- else:
801
- for ndx in range(len(self.data)):
802
- self.data[ndx] = self.data[ndx][:-1]
803
- self.columns = self.columns[:-1]
804
- raise err
805
-
806
- def get_column(self, name, convert_to=None):
807
- """Retrieves a column from the table and optionally converts it to a NumPy object.
808
-
809
- Arguments:
810
- name: (str) - the name of the column
811
- convert_to: (str, optional)
812
- - "numpy": will convert the underlying data to numpy object
813
- """
814
- assert name in self.columns
815
- assert convert_to is None or convert_to == "numpy"
816
- if convert_to == "numpy":
817
- np = util.get_module(
818
- "numpy", required="Converting to NumPy requires installing NumPy"
819
- )
820
- col = []
821
- col_ndx = self.columns.index(name)
822
- for row in self.data:
823
- item = row[col_ndx]
824
- if convert_to is not None and isinstance(item, WBValue):
825
- item = item.to_data_array()
826
- col.append(item)
827
- if convert_to == "numpy":
828
- col = np.array(col)
829
- return col
830
-
831
- def get_index(self):
832
- """Returns an array of row indexes for use in other tables to create links."""
833
- ndxs = []
834
- for ndx in range(len(self.data)):
835
- index = _TableIndex(ndx)
836
- index.set_table(self)
837
- ndxs.append(index)
838
- return ndxs
839
-
840
- def get_dataframe(self):
841
- """Returns a `pandas.DataFrame` of the table."""
842
- pd = util.get_module(
843
- "pandas",
844
- required="Converting to pandas.DataFrame requires installing pandas",
845
- )
846
- return pd.DataFrame.from_records(self.data, columns=self.columns)
847
-
848
- def index_ref(self, index):
849
- """Gets a reference of the index of a row in the table."""
850
- assert index < len(self.data)
851
- _index = _TableIndex(index)
852
- _index.set_table(self)
853
- return _index
854
-
855
- def add_computed_columns(self, fn):
856
- """Adds one or more computed columns based on existing data.
857
-
858
- Args:
859
- fn: A function which accepts one or two parameters, ndx (int) and row (dict),
860
- which is expected to return a dict representing new columns for that row, keyed
861
- by the new column names.
862
-
863
- `ndx` is an integer representing the index of the row. Only included if `include_ndx`
864
- is set to `True`.
865
-
866
- `row` is a dictionary keyed by existing columns
867
- """
868
- new_columns = {}
869
- for ndx, row in self.iterrows():
870
- row_dict = {self.columns[i]: row[i] for i in range(len(self.columns))}
871
- new_row_dict = fn(ndx, row_dict)
872
- assert isinstance(new_row_dict, dict)
873
- for key in new_row_dict:
874
- new_columns[key] = new_columns.get(key, [])
875
- new_columns[key].append(new_row_dict[key])
876
- for new_col_name in new_columns:
877
- self.add_column(new_col_name, new_columns[new_col_name])
878
-
879
-
880
- class _PartitionTablePartEntry:
881
- """Helper class for PartitionTable to track its parts."""
882
-
883
- def __init__(self, entry, source_artifact):
884
- self.entry = entry
885
- self.source_artifact = source_artifact
886
- self._part = None
887
-
888
- def get_part(self):
889
- if self._part is None:
890
- self._part = self.source_artifact.get(self.entry.path)
891
- return self._part
892
-
893
- def free(self):
894
- self._part = None
895
-
896
-
897
- class PartitionedTable(Media):
898
- """A table which is composed of multiple sub-tables.
899
-
900
- Currently, PartitionedTable is designed to point to a directory within an artifact.
901
- """
902
-
903
- _log_type = "partitioned-table"
904
-
905
- def __init__(self, parts_path):
906
- """Initialize a PartitionedTable.
907
-
908
- Args:
909
- parts_path (str): path to a directory of tables in the artifact.
910
- """
911
- super().__init__()
912
- self.parts_path = parts_path
913
- self._loaded_part_entries = {}
914
-
915
- def to_json(self, artifact_or_run):
916
- json_obj = {
917
- "_type": PartitionedTable._log_type,
918
- }
919
- if isinstance(artifact_or_run, wandb.wandb_sdk.wandb_run.Run):
920
- artifact_entry_url = self._get_artifact_entry_ref_url()
921
- if artifact_entry_url is None:
922
- raise ValueError(
923
- "PartitionedTables must first be added to an Artifact before logging to a Run"
924
- )
925
- json_obj["artifact_path"] = artifact_entry_url
926
- else:
927
- json_obj["parts_path"] = self.parts_path
928
- return json_obj
929
-
930
- @classmethod
931
- def from_json(cls, json_obj, source_artifact):
932
- instance = cls(json_obj["parts_path"])
933
- entries = source_artifact.manifest.get_entries_in_directory(
934
- json_obj["parts_path"]
935
- )
936
- for entry in entries:
937
- instance._add_part_entry(entry, source_artifact)
938
- return instance
939
-
940
- def iterrows(self):
941
- """Iterate over rows as (ndx, row).
942
-
943
- Yields:
944
- ------
945
- index : int
946
- The index of the row.
947
- row : List[any]
948
- The data of the row.
949
- """
950
- columns = None
951
- ndx = 0
952
- for entry_path in self._loaded_part_entries:
953
- part = self._loaded_part_entries[entry_path].get_part()
954
- if columns is None:
955
- columns = part.columns
956
- elif columns != part.columns:
957
- raise ValueError(
958
- "Table parts have non-matching columns. {} != {}".format(
959
- columns, part.columns
960
- )
961
- )
962
- for _, row in part.iterrows():
963
- yield ndx, row
964
- ndx += 1
965
-
966
- self._loaded_part_entries[entry_path].free()
967
-
968
- def _add_part_entry(self, entry, source_artifact):
969
- self._loaded_part_entries[entry.path] = _PartitionTablePartEntry(
970
- entry, source_artifact
971
- )
972
-
973
- def __ne__(self, other):
974
- return not self.__eq__(other)
975
-
976
- def __eq__(self, other):
977
- return isinstance(other, self.__class__) and self.parts_path == other.parts_path
978
-
979
- def bind_to_run(self, *args, **kwargs):
980
- raise ValueError("PartitionedTables cannot be bound to runs")
981
-
982
-
983
- class Audio(BatchableMedia):
984
- """Wandb class for audio clips.
985
-
986
- Arguments:
987
- data_or_path: (string or numpy array) A path to an audio file
988
- or a numpy array of audio data.
989
- sample_rate: (int) Sample rate, required when passing in raw
990
- numpy array of audio data.
991
- caption: (string) Caption to display with audio.
992
- """
993
-
994
- _log_type = "audio-file"
995
-
996
- def __init__(self, data_or_path, sample_rate=None, caption=None):
997
- """Accept a path to an audio file or a numpy array of audio data."""
998
- super().__init__()
999
- self._duration = None
1000
- self._sample_rate = sample_rate
1001
- self._caption = caption
1002
-
1003
- if isinstance(data_or_path, str):
1004
- if self.path_is_reference(data_or_path):
1005
- self._path = data_or_path
1006
- self._sha256 = hashlib.sha256(data_or_path.encode("utf-8")).hexdigest()
1007
- self._is_tmp = False
1008
- else:
1009
- self._set_file(data_or_path, is_tmp=False)
1010
- else:
1011
- if sample_rate is None:
1012
- raise ValueError(
1013
- 'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
1014
- )
1015
-
1016
- soundfile = util.get_module(
1017
- "soundfile",
1018
- required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"',
1019
- )
1020
-
1021
- tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".wav")
1022
- soundfile.write(tmp_path, data_or_path, sample_rate)
1023
- self._duration = len(data_or_path) / float(sample_rate)
1024
-
1025
- self._set_file(tmp_path, is_tmp=True)
1026
-
1027
- @classmethod
1028
- def get_media_subdir(cls):
1029
- return os.path.join("media", "audio")
1030
-
1031
- @classmethod
1032
- def from_json(cls, json_obj, source_artifact):
1033
- return cls(
1034
- source_artifact.get_entry(json_obj["path"]).download(),
1035
- caption=json_obj["caption"],
1036
- )
1037
-
1038
- def bind_to_run(
1039
- self, run, key, step, id_=None, ignore_copy_err: Optional[bool] = None
1040
- ):
1041
- if self.path_is_reference(self._path):
1042
- raise ValueError(
1043
- "Audio media created by a reference to external storage cannot currently be added to a run"
1044
- )
1045
-
1046
- return super().bind_to_run(run, key, step, id_, ignore_copy_err)
1047
-
1048
- def to_json(self, run):
1049
- json_dict = super().to_json(run)
1050
- json_dict.update(
1051
- {
1052
- "_type": self._log_type,
1053
- "caption": self._caption,
1054
- }
1055
- )
1056
- return json_dict
1057
-
1058
- @classmethod
1059
- def seq_to_json(cls, seq, run, key, step):
1060
- audio_list = list(seq)
1061
-
1062
- util.get_module(
1063
- "soundfile",
1064
- required="wandb.Audio requires the soundfile package. To get it, run: pip install soundfile",
1065
- )
1066
- base_path = os.path.join(run.dir, "media", "audio")
1067
- filesystem.mkdir_exists_ok(base_path)
1068
- meta = {
1069
- "_type": "audio",
1070
- "count": len(audio_list),
1071
- "audio": [a.to_json(run) for a in audio_list],
1072
- }
1073
- sample_rates = cls.sample_rates(audio_list)
1074
- if sample_rates:
1075
- meta["sampleRates"] = sample_rates
1076
- durations = cls.durations(audio_list)
1077
- if durations:
1078
- meta["durations"] = durations
1079
- captions = cls.captions(audio_list)
1080
- if captions:
1081
- meta["captions"] = captions
1082
-
1083
- return meta
1084
-
1085
- @classmethod
1086
- def durations(cls, audio_list):
1087
- return [a._duration for a in audio_list]
1088
-
1089
- @classmethod
1090
- def sample_rates(cls, audio_list):
1091
- return [a._sample_rate for a in audio_list]
1092
-
1093
- @classmethod
1094
- def captions(cls, audio_list):
1095
- captions = [a._caption for a in audio_list]
1096
- if all(c is None for c in captions):
1097
- return False
1098
- else:
1099
- return ["" if c is None else c for c in captions]
1100
-
1101
- def resolve_ref(self):
1102
- if self.path_is_reference(self._path):
1103
- # this object was already created using a ref:
1104
- return self._path
1105
- source_artifact = self._artifact_source.artifact
1106
-
1107
- resolved_name = source_artifact._local_path_to_name(self._path)
1108
- if resolved_name is not None:
1109
- target_entry = source_artifact.manifest.get_entry_by_path(resolved_name)
1110
- if target_entry is not None:
1111
- return target_entry.ref
1112
-
1113
- return None
1114
-
1115
- def __eq__(self, other):
1116
- if self.path_is_reference(self._path) or self.path_is_reference(other._path):
1117
- # one or more of these objects is an unresolved reference -- we'll compare
1118
- # their reference paths instead of their SHAs:
1119
- return (
1120
- self.resolve_ref() == other.resolve_ref()
1121
- and self._caption == other._caption
1122
- )
1123
-
1124
- return super().__eq__(other) and self._caption == other._caption
1125
-
1126
- def __ne__(self, other):
1127
- return not self.__eq__(other)
1128
-
1129
-
1130
- class JoinedTable(Media):
1131
- """Join two tables for visualization in the Artifact UI.
1132
-
1133
- Arguments:
1134
- table1 (str, wandb.Table, ArtifactManifestEntry):
1135
- the path to a wandb.Table in an artifact, the table object, or ArtifactManifestEntry
1136
- table2 (str, wandb.Table):
1137
- the path to a wandb.Table in an artifact, the table object, or ArtifactManifestEntry
1138
- join_key (str, [str, str]):
1139
- key or keys to perform the join
1140
- """
1141
-
1142
- _log_type = "joined-table"
1143
-
1144
- def __init__(self, table1, table2, join_key):
1145
- super().__init__()
1146
-
1147
- if not isinstance(join_key, str) and (
1148
- not isinstance(join_key, list) or len(join_key) != 2
1149
- ):
1150
- raise ValueError(
1151
- "JoinedTable join_key should be a string or a list of two strings"
1152
- )
1153
-
1154
- if not self._validate_table_input(table1):
1155
- raise ValueError(
1156
- "JoinedTable table1 should be an artifact path to a table or wandb.Table object"
1157
- )
1158
-
1159
- if not self._validate_table_input(table2):
1160
- raise ValueError(
1161
- "JoinedTable table2 should be an artifact path to a table or wandb.Table object"
1162
- )
1163
-
1164
- self._table1 = table1
1165
- self._table2 = table2
1166
- self._join_key = join_key
1167
-
1168
- @classmethod
1169
- def from_json(cls, json_obj, source_artifact):
1170
- t1 = source_artifact.get(json_obj["table1"])
1171
- if t1 is None:
1172
- t1 = json_obj["table1"]
1173
-
1174
- t2 = source_artifact.get(json_obj["table2"])
1175
- if t2 is None:
1176
- t2 = json_obj["table2"]
1177
-
1178
- return cls(
1179
- t1,
1180
- t2,
1181
- json_obj["join_key"],
1182
- )
1183
-
1184
- @staticmethod
1185
- def _validate_table_input(table):
1186
- """Helper method to validate that the table input is one of the 3 supported types."""
1187
- return (
1188
- (isinstance(table, str) and table.endswith(".table.json"))
1189
- or isinstance(table, Table)
1190
- or isinstance(table, PartitionedTable)
1191
- or (hasattr(table, "ref_url") and table.ref_url().endswith(".table.json"))
1192
- )
1193
-
1194
- def _ensure_table_in_artifact(self, table, artifact, table_ndx):
1195
- """Helper method to add the table to the incoming artifact. Returns the path."""
1196
- if isinstance(table, Table) or isinstance(table, PartitionedTable):
1197
- table_name = f"t{table_ndx}_{str(id(self))}"
1198
- if (
1199
- table._artifact_source is not None
1200
- and table._artifact_source.name is not None
1201
- ):
1202
- table_name = os.path.basename(table._artifact_source.name)
1203
- entry = artifact.add(table, table_name)
1204
- table = entry.path
1205
- # Check if this is an ArtifactManifestEntry
1206
- elif hasattr(table, "ref_url"):
1207
- # Give the new object a unique, yet deterministic name
1208
- name = binascii.hexlify(base64.standard_b64decode(table.digest)).decode(
1209
- "ascii"
1210
- )[:20]
1211
- entry = artifact.add_reference(
1212
- table.ref_url(), "{}.{}.json".format(name, table.name.split(".")[-2])
1213
- )[0]
1214
- table = entry.path
1215
-
1216
- err_str = "JoinedTable table:{} not found in artifact. Add a table to the artifact using Artifact#add(<table>, {}) before adding this JoinedTable"
1217
- if table not in artifact._manifest.entries:
1218
- raise ValueError(err_str.format(table, table))
1219
-
1220
- return table
1221
-
1222
- def to_json(self, artifact_or_run):
1223
- json_obj = {
1224
- "_type": JoinedTable._log_type,
1225
- }
1226
- if isinstance(artifact_or_run, wandb.wandb_sdk.wandb_run.Run):
1227
- artifact_entry_url = self._get_artifact_entry_ref_url()
1228
- if artifact_entry_url is None:
1229
- raise ValueError(
1230
- "JoinedTables must first be added to an Artifact before logging to a Run"
1231
- )
1232
- json_obj["artifact_path"] = artifact_entry_url
1233
- else:
1234
- table1 = self._ensure_table_in_artifact(self._table1, artifact_or_run, 1)
1235
- table2 = self._ensure_table_in_artifact(self._table2, artifact_or_run, 2)
1236
- json_obj.update(
1237
- {
1238
- "table1": table1,
1239
- "table2": table2,
1240
- "join_key": self._join_key,
1241
- }
1242
- )
1243
- return json_obj
1244
-
1245
- def __ne__(self, other):
1246
- return not self.__eq__(other)
1247
-
1248
- def _eq_debug(self, other, should_assert=False):
1249
- eq = isinstance(other, JoinedTable)
1250
- assert not should_assert or eq, "Found type {}, expected {}".format(
1251
- other.__class__, JoinedTable
1252
- )
1253
- eq = eq and self._join_key == other._join_key
1254
- assert not should_assert or eq, "Found {} join key, expected {}".format(
1255
- other._join_key, self._join_key
1256
- )
1257
- eq = eq and self._table1._eq_debug(other._table1, should_assert)
1258
- eq = eq and self._table2._eq_debug(other._table2, should_assert)
1259
- return eq
1260
-
1261
- def __eq__(self, other):
1262
- return self._eq_debug(other, False)
1263
-
1264
- def bind_to_run(self, *args, **kwargs):
1265
- raise ValueError("JoinedTables cannot be bound to runs")
1266
-
1267
-
1268
- class Bokeh(Media):
1269
- """Wandb class for Bokeh plots.
1270
-
1271
- Arguments:
1272
- val: Bokeh plot
1273
- """
1274
-
1275
- _log_type = "bokeh-file"
1276
-
1277
- def __init__(self, data_or_path):
1278
- super().__init__()
1279
- bokeh = util.get_module("bokeh", required=True)
1280
- if isinstance(data_or_path, str) and os.path.exists(data_or_path):
1281
- with open(data_or_path) as file:
1282
- b_json = json.load(file)
1283
- self.b_obj = bokeh.document.Document.from_json(b_json)
1284
- self._set_file(data_or_path, is_tmp=False, extension=".bokeh.json")
1285
- elif isinstance(data_or_path, bokeh.model.Model):
1286
- _data = bokeh.document.Document()
1287
- _data.add_root(data_or_path)
1288
- # serialize/deserialize pairing followed by sorting attributes ensures
1289
- # that the file's sha's are equivalent in subsequent calls
1290
- self.b_obj = bokeh.document.Document.from_json(_data.to_json())
1291
- b_json = self.b_obj.to_json()
1292
- if "references" in b_json["roots"]:
1293
- b_json["roots"]["references"].sort(key=lambda x: x["id"])
1294
-
1295
- tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".bokeh.json")
1296
- with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
1297
- util.json_dump_safer(b_json, fp)
1298
- self._set_file(tmp_path, is_tmp=True, extension=".bokeh.json")
1299
- elif not isinstance(data_or_path, bokeh.document.Document):
1300
- raise TypeError(
1301
- "Bokeh constructor accepts Bokeh document/model or path to Bokeh json file"
1302
- )
1303
-
1304
- def get_media_subdir(self):
1305
- return os.path.join("media", "bokeh")
1306
-
1307
- def to_json(self, run):
1308
- # TODO: (tss) this is getting redundant for all the media objects. We can probably
1309
- # pull this into Media#to_json and remove this type override for all the media types.
1310
- # There are only a few cases where the type is different between artifacts and runs.
1311
- json_dict = super().to_json(run)
1312
- json_dict["_type"] = self._log_type
1313
- return json_dict
1314
-
1315
- @classmethod
1316
- def from_json(cls, json_obj, source_artifact):
1317
- return cls(source_artifact.get_entry(json_obj["path"]).download())
1318
-
1319
-
1320
- def _nest(thing):
1321
- # Use tensorflows nest function if available, otherwise just wrap object in an array"""
1322
-
1323
- tfutil = util.get_module("tensorflow.python.util")
1324
- if tfutil:
1325
- return tfutil.nest.flatten(thing)
1326
- else:
1327
- return [thing]
1328
-
1329
-
1330
- class Graph(Media):
1331
- """Wandb class for graphs.
1332
-
1333
- This class is typically used for saving and displaying neural net models. It
1334
- represents the graph as an array of nodes and edges. The nodes can have
1335
- labels that can be visualized by wandb.
1336
-
1337
- Examples:
1338
- Import a keras model:
1339
- ```
1340
- Graph.from_keras(keras_model)
1341
- ```
1342
-
1343
- Attributes:
1344
- format (string): Format to help wandb display the graph nicely.
1345
- nodes ([wandb.Node]): List of wandb.Nodes
1346
- nodes_by_id (dict): dict of ids -> nodes
1347
- edges ([(wandb.Node, wandb.Node)]): List of pairs of nodes interpreted as edges
1348
- loaded (boolean): Flag to tell whether the graph is completely loaded
1349
- root (wandb.Node): root node of the graph
1350
- """
1351
-
1352
- _log_type = "graph-file"
1353
-
1354
- def __init__(self, format="keras"):
1355
- super().__init__()
1356
- # LB: TODO: I think we should factor criterion and criterion_passed out
1357
- self.format = format
1358
- self.nodes = []
1359
- self.nodes_by_id = {}
1360
- self.edges = []
1361
- self.loaded = False
1362
- self.criterion = None
1363
- self.criterion_passed = False
1364
- self.root = None # optional root Node if applicable
1365
-
1366
- def _to_graph_json(self, run=None):
1367
- # Needs to be its own function for tests
1368
- return {
1369
- "format": self.format,
1370
- "nodes": [node.to_json() for node in self.nodes],
1371
- "edges": [edge.to_json() for edge in self.edges],
1372
- }
1373
-
1374
- def bind_to_run(self, *args, **kwargs):
1375
- data = self._to_graph_json()
1376
- tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".graph.json")
1377
- data = _numpy_arrays_to_lists(data)
1378
- with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
1379
- util.json_dump_safer(data, fp)
1380
- self._set_file(tmp_path, is_tmp=True, extension=".graph.json")
1381
- if self.is_bound():
1382
- return
1383
- super().bind_to_run(*args, **kwargs)
1384
-
1385
- @classmethod
1386
- def get_media_subdir(cls):
1387
- return os.path.join("media", "graph")
1388
-
1389
- def to_json(self, run):
1390
- json_dict = super().to_json(run)
1391
- json_dict["_type"] = self._log_type
1392
- return json_dict
1393
-
1394
- def __getitem__(self, nid):
1395
- return self.nodes_by_id[nid]
1396
-
1397
- def pprint(self):
1398
- for edge in self.edges:
1399
- pprint.pprint(edge.attributes)
1400
- for node in self.nodes:
1401
- pprint.pprint(node.attributes)
1402
-
1403
- def add_node(self, node=None, **node_kwargs):
1404
- if node is None:
1405
- node = Node(**node_kwargs)
1406
- elif node_kwargs:
1407
- raise ValueError(
1408
- f"Only pass one of either node ({node}) or other keyword arguments ({node_kwargs})"
1409
- )
1410
- self.nodes.append(node)
1411
- self.nodes_by_id[node.id] = node
1412
-
1413
- return node
1414
-
1415
- def add_edge(self, from_node, to_node):
1416
- edge = Edge(from_node, to_node)
1417
- self.edges.append(edge)
1418
-
1419
- return edge
1420
-
1421
- @classmethod
1422
- def from_keras(cls, model):
1423
- # TODO: his method requires a refactor to work with the keras 3.
1424
- graph = cls()
1425
- # Shamelessly copied (then modified) from keras/keras/utils/layer_utils.py
1426
- sequential_like = cls._is_sequential(model)
1427
-
1428
- relevant_nodes = None
1429
- if not sequential_like:
1430
- relevant_nodes = []
1431
- for v in model._nodes_by_depth.values():
1432
- relevant_nodes += v
1433
-
1434
- layers = model.layers
1435
- for i in range(len(layers)):
1436
- node = Node.from_keras(layers[i])
1437
- if hasattr(layers[i], "_inbound_nodes"):
1438
- for in_node in layers[i]._inbound_nodes:
1439
- if relevant_nodes and in_node not in relevant_nodes:
1440
- # node is not part of the current network
1441
- continue
1442
- for in_layer in _nest(in_node.inbound_layers):
1443
- inbound_keras_node = Node.from_keras(in_layer)
1444
-
1445
- if inbound_keras_node.id not in graph.nodes_by_id:
1446
- graph.add_node(inbound_keras_node)
1447
- inbound_node = graph.nodes_by_id[inbound_keras_node.id]
1448
-
1449
- graph.add_edge(inbound_node, node)
1450
- graph.add_node(node)
1451
- return graph
1452
-
1453
- @classmethod
1454
- def _is_sequential(cls, model):
1455
- sequential_like = True
1456
-
1457
- if (
1458
- model.__class__.__name__ != "Sequential"
1459
- and hasattr(model, "_is_graph_network")
1460
- and model._is_graph_network
1461
- ):
1462
- nodes_by_depth = model._nodes_by_depth.values()
1463
- nodes = []
1464
- for v in nodes_by_depth:
1465
- # TensorFlow2 doesn't insure inbound is always a list
1466
- inbound = v[0].inbound_layers
1467
- if not hasattr(inbound, "__len__"):
1468
- inbound = [inbound]
1469
- if (len(v) > 1) or (len(v) == 1 and len(inbound) > 1):
1470
- # if the model has multiple nodes
1471
- # or if the nodes have multiple inbound_layers
1472
- # the model is no longer sequential
1473
- sequential_like = False
1474
- break
1475
- nodes += v
1476
- if sequential_like:
1477
- # search for shared layers
1478
- for layer in model.layers:
1479
- flag = False
1480
- if hasattr(layer, "_inbound_nodes"):
1481
- for node in layer._inbound_nodes:
1482
- if node in nodes:
1483
- if flag:
1484
- sequential_like = False
1485
- break
1486
- else:
1487
- flag = True
1488
- if not sequential_like:
1489
- break
1490
- return sequential_like
1491
-
1492
-
1493
- class Node(WBValue):
-     """Node used in `Graph`."""
-
-     def __init__(
-         self,
-         id=None,
-         name=None,
-         class_name=None,
-         size=None,
-         parameters=None,
-         output_shape=None,
-         is_output=None,
-         num_parameters=None,
-         node=None,
-     ):
-         self._attributes = {"name": None}
-         self.in_edges = {}  # indexed by source node id
-         self.out_edges = {}  # indexed by dest node id
-         # optional object (e.g. PyTorch Parameter or Module) that this Node represents
-         self.obj = None
-
-         if node is not None:
-             self._attributes.update(node._attributes)
-             del self._attributes["id"]
-             self.obj = node.obj
-
-         if id is not None:
-             self.id = id
-         if name is not None:
-             self.name = name
-         if class_name is not None:
-             self.class_name = class_name
-         if size is not None:
-             self.size = size
-         if parameters is not None:
-             self.parameters = parameters
-         if output_shape is not None:
-             self.output_shape = output_shape
-         if is_output is not None:
-             self.is_output = is_output
-         if num_parameters is not None:
-             self.num_parameters = num_parameters
-
-     def to_json(self, run=None):
-         return self._attributes
-
-     def __repr__(self):
-         return repr(self._attributes)
-
-     @property
-     def id(self):
-         """Must be unique in the graph."""
-         return self._attributes.get("id")
-
-     @id.setter
-     def id(self, val):
-         self._attributes["id"] = val
-         return val
-
-     @property
-     def name(self):
-         """Usually the type of layer or sublayer."""
-         return self._attributes.get("name")
-
-     @name.setter
-     def name(self, val):
-         self._attributes["name"] = val
-         return val
-
-     @property
-     def class_name(self):
-         """Usually the type of layer or sublayer."""
-         return self._attributes.get("class_name")
-
-     @class_name.setter
-     def class_name(self, val):
-         self._attributes["class_name"] = val
-         return val
-
-     @property
-     def functions(self):
-         return self._attributes.get("functions", [])
-
-     @functions.setter
-     def functions(self, val):
-         self._attributes["functions"] = val
-         return val
-
-     @property
-     def parameters(self):
-         return self._attributes.get("parameters", [])
-
-     @parameters.setter
-     def parameters(self, val):
-         self._attributes["parameters"] = val
-         return val
-
-     @property
-     def size(self):
-         return self._attributes.get("size")
-
-     @size.setter
-     def size(self, val):
-         """Tensor size."""
-         self._attributes["size"] = tuple(val)
-         return val
-
-     @property
-     def output_shape(self):
-         return self._attributes.get("output_shape")
-
-     @output_shape.setter
-     def output_shape(self, val):
-         """Tensor output_shape."""
-         self._attributes["output_shape"] = val
-         return val
-
-     @property
-     def is_output(self):
-         return self._attributes.get("is_output")
-
-     @is_output.setter
-     def is_output(self, val):
-         """Tensor is_output."""
-         self._attributes["is_output"] = val
-         return val
-
-     @property
-     def num_parameters(self):
-         return self._attributes.get("num_parameters")
-
-     @num_parameters.setter
-     def num_parameters(self, val):
-         """Tensor num_parameters."""
-         self._attributes["num_parameters"] = val
-         return val
-
-     @property
-     def child_parameters(self):
-         return self._attributes.get("child_parameters")
-
-     @child_parameters.setter
-     def child_parameters(self, val):
-         """Tensor child_parameters."""
-         self._attributes["child_parameters"] = val
-         return val
-
-     @property
-     def is_constant(self):
-         return self._attributes.get("is_constant")
-
-     @is_constant.setter
-     def is_constant(self, val):
-         """Tensor is_constant."""
-         self._attributes["is_constant"] = val
-         return val
-
-     @classmethod
-     def from_keras(cls, layer):
-         node = cls()
-
-         try:
-             output_shape = layer.output_shape
-         except AttributeError:
-             output_shape = ["multiple"]
-
-         node.id = layer.name
-         node.name = layer.name
-         node.class_name = layer.__class__.__name__
-         node.output_shape = output_shape
-         node.num_parameters = layer.count_params()
-
-         return node
-
-
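`Node` is a thin wrapper over an `_attributes` dict, and `to_json` simply returns that dict. A small sketch of the behaviour, assuming `Node` stays importable from the new `wandb/sdk/data_types/graph.py` module:

```python
# Sketch only; import path is the assumed post-move location.
from wandb.sdk.data_types.graph import Node

node = Node(
    id="fc1",
    name="fc1",
    class_name="Dense",
    output_shape=(None, 8),
    num_parameters=40,
)
print(node.to_json())
# {'name': 'fc1', 'id': 'fc1', 'class_name': 'Dense',
#  'output_shape': (None, 8), 'num_parameters': 40}
```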
- class Edge(WBValue):
-     """Edge used in `Graph`."""
-
-     def __init__(self, from_node, to_node):
-         self._attributes = {}
-         self.from_node = from_node
-         self.to_node = to_node
-
-     def __repr__(self):
-         temp_attr = dict(self._attributes)
-         del temp_attr["from_node"]
-         del temp_attr["to_node"]
-         temp_attr["from_id"] = self.from_node.id
-         temp_attr["to_id"] = self.to_node.id
-         return str(temp_attr)
-
-     def to_json(self, run=None):
-         return [self.from_node.id, self.to_node.id]
-
-     @property
-     def name(self):
-         """Optional, not necessarily unique."""
-         return self._attributes.get("name")
-
-     @name.setter
-     def name(self, val):
-         self._attributes["name"] = val
-         return val
-
-     @property
-     def from_node(self):
-         return self._attributes.get("from_node")
-
-     @from_node.setter
-     def from_node(self, val):
-         self._attributes["from_node"] = val
-         return val
-
-     @property
-     def to_node(self):
-         return self._attributes.get("to_node")
-
-     @to_node.setter
-     def to_node(self, val):
-         self._attributes["to_node"] = val
-         return val
-
-
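`Edge.to_json` serializes to a `[from_id, to_id]` pair, while `Graph.add_node`/`add_edge` (seen earlier in this file) maintain the node and edge lists. A hedged sketch of manual graph assembly, assuming the same post-move module location:

```python
# Sketch only; relies on Graph keeping its nodes/edges list attributes.
from wandb.sdk.data_types.graph import Graph, Node

graph = Graph()
a = Node(id="input", class_name="InputLayer")
b = Node(id="dense", class_name="Dense")
graph.add_node(a)
graph.add_node(b)
graph.add_edge(a, b)

print(graph.edges[0].to_json())  # ['input', 'dense']
print(len(graph.nodes))          # 2
```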
- # Custom dtypes for typing system
- class _ImageFileType(_dtypes.Type):
-     name = "image-file"
-     legacy_names = ["wandb.Image"]
-     types = [Image]
-
-     def __init__(
-         self,
-         box_layers=None,
-         box_score_keys=None,
-         mask_layers=None,
-         class_map=None,
-         **kwargs,
-     ):
-         box_layers = box_layers or {}
-         box_score_keys = box_score_keys or []
-         mask_layers = mask_layers or {}
-         class_map = class_map or {}
-
-         if isinstance(box_layers, _dtypes.ConstType):
-             box_layers = box_layers._params["val"]
-         if not isinstance(box_layers, dict):
-             raise TypeError("box_layers must be a dict")
-         else:
-             box_layers = _dtypes.ConstType(
-                 {layer_key: set(box_layers[layer_key]) for layer_key in box_layers}
-             )
-
-         if isinstance(mask_layers, _dtypes.ConstType):
-             mask_layers = mask_layers._params["val"]
-         if not isinstance(mask_layers, dict):
-             raise TypeError("mask_layers must be a dict")
-         else:
-             mask_layers = _dtypes.ConstType(
-                 {layer_key: set(mask_layers[layer_key]) for layer_key in mask_layers}
-             )
-
-         if isinstance(box_score_keys, _dtypes.ConstType):
-             box_score_keys = box_score_keys._params["val"]
-         if not isinstance(box_score_keys, list) and not isinstance(box_score_keys, set):
-             raise TypeError("box_score_keys must be a list or a set")
-         else:
-             box_score_keys = _dtypes.ConstType(set(box_score_keys))
-
-         if isinstance(class_map, _dtypes.ConstType):
-             class_map = class_map._params["val"]
-         if not isinstance(class_map, dict):
-             raise TypeError("class_map must be a dict")
-         else:
-             class_map = _dtypes.ConstType(class_map)
-
-         self.params.update(
-             {
-                 "box_layers": box_layers,
-                 "box_score_keys": box_score_keys,
-                 "mask_layers": mask_layers,
-                 "class_map": class_map,
-             }
-         )
-
-     def assign_type(self, wb_type=None):
-         if isinstance(wb_type, _ImageFileType):
-             box_layers_self = self.params["box_layers"].params["val"] or {}
-             box_score_keys_self = self.params["box_score_keys"].params["val"] or []
-             mask_layers_self = self.params["mask_layers"].params["val"] or {}
-             class_map_self = self.params["class_map"].params["val"] or {}
-
-             box_layers_other = wb_type.params["box_layers"].params["val"] or {}
-             box_score_keys_other = wb_type.params["box_score_keys"].params["val"] or []
-             mask_layers_other = wb_type.params["mask_layers"].params["val"] or {}
-             class_map_other = wb_type.params["class_map"].params["val"] or {}
-
-             # Merge the class_ids from each set of box_layers
-             box_layers = {
-                 str(key): set(
-                     list(box_layers_self.get(key, []))
-                     + list(box_layers_other.get(key, []))
-                 )
-                 for key in set(
-                     list(box_layers_self.keys()) + list(box_layers_other.keys())
-                 )
-             }
-
-             # Merge the class_ids from each set of mask_layers
-             mask_layers = {
-                 str(key): set(
-                     list(mask_layers_self.get(key, []))
-                     + list(mask_layers_other.get(key, []))
-                 )
-                 for key in set(
-                     list(mask_layers_self.keys()) + list(mask_layers_other.keys())
-                 )
-             }
-
-             # Merge the box score keys
-             box_score_keys = set(list(box_score_keys_self) + list(box_score_keys_other))
-
-             # Merge the class_map
-             class_map = {
-                 str(key): class_map_self.get(key, class_map_other.get(key, None))
-                 for key in set(
-                     list(class_map_self.keys()) + list(class_map_other.keys())
-                 )
-             }
-
-             return _ImageFileType(box_layers, box_score_keys, mask_layers, class_map)
-
-         return _dtypes.InvalidType()
-
-     @classmethod
-     def from_obj(cls, py_obj):
-         if not isinstance(py_obj, Image):
-             raise TypeError("py_obj must be a wandb.Image")
-         else:
-             if hasattr(py_obj, "_boxes") and py_obj._boxes:
-                 box_layers = {
-                     str(key): set(py_obj._boxes[key]._class_labels.keys())
-                     for key in py_obj._boxes.keys()
-                 }
-                 box_score_keys = {
-                     key
-                     for val in py_obj._boxes.values()
-                     for box in val._val
-                     for key in box.get("scores", {}).keys()
-                 }
-
-             else:
-                 box_layers = {}
-                 box_score_keys = set()
-
-             if hasattr(py_obj, "_masks") and py_obj._masks:
-                 mask_layers = {
-                     str(key): set(
-                         py_obj._masks[key]._val["class_labels"].keys()
-                         if hasattr(py_obj._masks[key], "_val")
-                         else []
-                     )
-                     for key in py_obj._masks.keys()
-                 }
-             else:
-                 mask_layers = {}
-
-             if hasattr(py_obj, "_classes") and py_obj._classes:
-                 class_set = {
-                     str(item["id"]): item["name"] for item in py_obj._classes._class_set
-                 }
-             else:
-                 class_set = {}
-
-             return cls(box_layers, box_score_keys, mask_layers, class_set)
-
-
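`_ImageFileType.assign_type` unions the box/mask layer class ids, score keys, and class maps of two image types, which is what lets one Table column accept images logged with different label sets. The same union-merge in plain dict/set form, runnable without wandb:

```python
# Sketch of the union-merge performed per layer in assign_type above.
def merge_layers(self_layers: dict, other_layers: dict) -> dict:
    return {
        str(key): set(self_layers.get(key, [])) | set(other_layers.get(key, []))
        for key in set(self_layers) | set(other_layers)
    }

print(merge_layers({"predictions": {1, 2}}, {"predictions": {2, 3}, "gt": {1}}))
# {'predictions': {1, 2, 3}, 'gt': {1}}
```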
- class _TableType(_dtypes.Type):
-     name = "table"
-     legacy_names = ["wandb.Table"]
-     types = [Table]
-
-     def __init__(self, column_types=None):
-         if column_types is None:
-             column_types = _dtypes.UnknownType()
-         if isinstance(column_types, dict):
-             column_types = _dtypes.TypedDictType(column_types)
-         elif not (
-             isinstance(column_types, _dtypes.TypedDictType)
-             or isinstance(column_types, _dtypes.UnknownType)
-         ):
-             raise TypeError("column_types must be a dict or TypedDictType")
-
-         self.params.update({"column_types": column_types})
-
-     def assign_type(self, wb_type=None):
-         if isinstance(wb_type, _TableType):
-             column_types = self.params["column_types"].assign_type(
-                 wb_type.params["column_types"]
-             )
-             if not isinstance(column_types, _dtypes.InvalidType):
-                 return _TableType(column_types)
-
-         return _dtypes.InvalidType()
-
-     @classmethod
-     def from_obj(cls, py_obj):
-         if not isinstance(py_obj, Table):
-             raise TypeError("py_obj must be a wandb.Table")
-         else:
-             return cls(py_obj._column_types)
-
-
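`_TableType` reduces a table's type to the typed-dict type of its columns, and `assign_type` narrows the two column dicts pairwise, returning `InvalidType` when any column is incompatible. A sketch of that per-column narrowing using the `_dtypes` primitives this file already imports:

```python
# Sketch of the column-wise assignment _TableType.assign_type delegates to.
from wandb.sdk.data_types import _dtypes

t1 = _dtypes.TypedDictType({"id": _dtypes.NumberType(), "label": _dtypes.StringType()})
t2 = _dtypes.TypedDictType({"id": _dtypes.NumberType(), "label": _dtypes.StringType()})
t3 = _dtypes.TypedDictType({"id": _dtypes.StringType(), "label": _dtypes.StringType()})

print(type(t1.assign_type(t2)).__name__)  # TypedDictType: every column is compatible
print(type(t1.assign_type(t3)).__name__)  # InvalidType: "id" cannot be both Number and String
```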
- class _ForeignKeyType(_dtypes.Type):
-     name = "foreignKey"
-     legacy_names = ["wandb.TableForeignKey"]
-     types = [_TableKey]
-
-     def __init__(self, table, col_name):
-         assert isinstance(table, Table)
-         assert isinstance(col_name, str)
-         assert col_name in table.columns
-         self.params.update({"table": table, "col_name": col_name})
-
-     def assign_type(self, wb_type=None):
-         if isinstance(wb_type, _dtypes.StringType):
-             return self
-         elif (
-             isinstance(wb_type, _ForeignKeyType)
-             and id(self.params["table"]) == id(wb_type.params["table"])
-             and self.params["col_name"] == wb_type.params["col_name"]
-         ):
-             return self
-
-         return _dtypes.InvalidType()
-
-     @classmethod
-     def from_obj(cls, py_obj):
-         if not isinstance(py_obj, _TableKey):
-             raise TypeError("py_obj must be a _TableKey")
-         else:
-             return cls(py_obj._table, py_obj._col_name)
-
-     def to_json(self, artifact=None):
-         res = super().to_json(artifact)
-         if artifact is not None:
-             table_name = f"media/tables/t_{runid.generate_id()}"
-             entry = artifact.add(self.params["table"], table_name)
-             res["params"]["table"] = entry.path
-         else:
-             raise AssertionError(
-                 "_ForeignKeyType does not support serialization without an artifact"
-             )
-         return res
-
-     @classmethod
-     def from_json(
-         cls,
-         json_dict,
-         artifact,
-     ):
-         table = None
-         col_name = None
-         if artifact is None:
-             raise AssertionError(
-                 "_ForeignKeyType does not support deserialization without an artifact"
-             )
-         else:
-             table = artifact.get(json_dict["params"]["table"])
-             col_name = json_dict["params"]["col_name"]
-
-         if table is None:
-             raise AssertionError("Unable to deserialize referenced table")
-
-         return cls(table, col_name)
-
-
- class _ForeignIndexType(_dtypes.Type):
-     name = "foreignIndex"
-     legacy_names = ["wandb.TableForeignIndex"]
-     types = [_TableIndex]
-
-     def __init__(self, table):
-         assert isinstance(table, Table)
-         self.params.update({"table": table})
-
-     def assign_type(self, wb_type=None):
-         if isinstance(wb_type, _dtypes.NumberType):
-             return self
-         elif isinstance(wb_type, _ForeignIndexType) and id(self.params["table"]) == id(
-             wb_type.params["table"]
-         ):
-             return self
-
-         return _dtypes.InvalidType()
-
-     @classmethod
-     def from_obj(cls, py_obj):
-         if not isinstance(py_obj, _TableIndex):
-             raise TypeError("py_obj must be a _TableIndex")
-         else:
-             return cls(py_obj._table)
-
-     def to_json(self, artifact=None):
-         res = super().to_json(artifact)
-         if artifact is not None:
-             table_name = f"media/tables/t_{runid.generate_id()}"
-             entry = artifact.add(self.params["table"], table_name)
-             res["params"]["table"] = entry.path
-         else:
-             raise AssertionError(
-                 "_ForeignIndexType does not support serialization without an artifact"
-             )
-         return res
-
-     @classmethod
-     def from_json(
-         cls,
-         json_dict,
-         artifact,
-     ):
-         table = None
-         if artifact is None:
-             raise AssertionError(
-                 "_ForeignIndexType does not support deserialization without an artifact"
-             )
-         else:
-             table = artifact.get(json_dict["params"]["table"])
-
-         if table is None:
-             raise AssertionError("Unable to deserialize referenced table")
-
-         return cls(table)
-
-
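Both the foreign-key and foreign-index types can only be serialized or deserialized against an artifact, because the referenced table is written into the artifact and only its entry path is kept in the type params. A rough sketch, assuming `_ForeignIndexType` now lives in `wandb/sdk/data_types/table.py` and that the entry returned by `Artifact.add` exposes `.path`:

```python
# Rough sketch; module path, artifact name/type, and the .path attribute are assumptions.
import wandb
from wandb.sdk.data_types.table import _ForeignIndexType  # assumed new location

source = wandb.Table(columns=["label"], data=[["cat"], ["dog"]])
index_type = _ForeignIndexType(source)

artifact = wandb.Artifact("lookup-tables", type="dataset")  # hypothetical name/type
blob = index_type.to_json(artifact)   # adds the table to the artifact, records its path
print(blob["params"]["table"])        # path of the stored table entry inside the artifact
# Without an artifact, serialization is rejected (the AssertionError branch above).
```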
- class _PrimaryKeyType(_dtypes.Type):
-     name = "primaryKey"
-     legacy_names = ["wandb.TablePrimaryKey"]
-
-     def assign_type(self, wb_type=None):
-         if isinstance(wb_type, _dtypes.StringType) or isinstance(
-             wb_type, _PrimaryKeyType
-         ):
-             return self
-         return _dtypes.InvalidType()
-
-     @classmethod
-     def from_obj(cls, py_obj):
-         if not isinstance(py_obj, _TableKey):
-             raise TypeError("py_obj must be a _TableKey")
-         else:
-             return cls()
-
-
- class _AudioFileType(_dtypes.Type):
-     name = "audio-file"
-     types = [Audio]
-
-
- class _BokehFileType(_dtypes.Type):
-     name = "bokeh-file"
-     types = [Bokeh]
-
-
- class _JoinedTableType(_dtypes.Type):
-     name = "joined-table"
-     types = [JoinedTable]
-
-
- class _PartitionedTableType(_dtypes.Type):
-     name = "partitioned-table"
-     types = [PartitionedTable]
-
-
- _dtypes.TypeRegistry.add(_AudioFileType)
- _dtypes.TypeRegistry.add(_BokehFileType)
- _dtypes.TypeRegistry.add(_ImageFileType)
- _dtypes.TypeRegistry.add(_TableType)
- _dtypes.TypeRegistry.add(_JoinedTableType)
- _dtypes.TypeRegistry.add(_PartitionedTableType)
- _dtypes.TypeRegistry.add(_ForeignKeyType)
- _dtypes.TypeRegistry.add(_PrimaryKeyType)
- _dtypes.TypeRegistry.add(_ForeignIndexType)
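These registry calls are what connect the dtype classes above to runtime type inference; after the move, the same registrations happen from the new module locations under wandb/sdk/data_types/. A small sketch of resolving a logged object through the registry, assuming `TypeRegistry.type_of` keeps its current behaviour:

```python
# Sketch only: resolves a logged object to the "table" dtype registered above.
import wandb
from wandb.sdk.data_types import _dtypes

table = wandb.Table(columns=["id", "score"], data=[[1, 0.5]])
print(_dtypes.TypeRegistry.type_of(table).name)  # "table", handled by _TableType
```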