wandb 0.17.0rc2__py3-none-win32.whl → 0.17.2__py3-none-win32.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (164) hide show
  1. wandb/__init__.py +4 -2
  2. wandb/apis/importers/internals/internal.py +0 -1
  3. wandb/apis/importers/wandb.py +12 -7
  4. wandb/apis/internal.py +0 -3
  5. wandb/apis/public/api.py +213 -79
  6. wandb/apis/public/artifacts.py +335 -100
  7. wandb/apis/public/files.py +9 -9
  8. wandb/apis/public/jobs.py +16 -4
  9. wandb/apis/public/projects.py +26 -28
  10. wandb/apis/public/query_generator.py +1 -1
  11. wandb/apis/public/runs.py +163 -65
  12. wandb/apis/public/sweeps.py +2 -2
  13. wandb/apis/reports/__init__.py +1 -7
  14. wandb/apis/reports/v1/__init__.py +5 -27
  15. wandb/apis/reports/v2/__init__.py +7 -19
  16. wandb/apis/workspaces/__init__.py +8 -0
  17. wandb/beta/workflows.py +8 -3
  18. wandb/bin/wandb-core +0 -0
  19. wandb/cli/cli.py +151 -59
  20. wandb/docker/__init__.py +1 -1
  21. wandb/errors/term.py +10 -2
  22. wandb/filesync/step_checksum.py +1 -4
  23. wandb/filesync/step_prepare.py +4 -24
  24. wandb/filesync/step_upload.py +5 -107
  25. wandb/filesync/upload_job.py +0 -76
  26. wandb/integration/gym/__init__.py +35 -15
  27. wandb/integration/openai/fine_tuning.py +21 -3
  28. wandb/integration/prodigy/prodigy.py +1 -1
  29. wandb/jupyter.py +16 -17
  30. wandb/old/summary.py +5 -0
  31. wandb/plot/pr_curve.py +2 -1
  32. wandb/plot/roc_curve.py +2 -1
  33. wandb/{plots → plot}/utils.py +13 -25
  34. wandb/proto/v3/wandb_internal_pb2.py +54 -54
  35. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  36. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  37. wandb/proto/v4/wandb_internal_pb2.py +54 -54
  38. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  39. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  40. wandb/proto/v5/wandb_base_pb2.py +30 -0
  41. wandb/proto/v5/wandb_internal_pb2.py +355 -0
  42. wandb/proto/v5/wandb_server_pb2.py +63 -0
  43. wandb/proto/v5/wandb_settings_pb2.py +45 -0
  44. wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
  45. wandb/proto/wandb_base_pb2.py +2 -0
  46. wandb/proto/wandb_deprecated.py +9 -1
  47. wandb/proto/wandb_generate_deprecated.py +34 -0
  48. wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
  49. wandb/proto/wandb_internal_pb2.py +2 -0
  50. wandb/proto/wandb_server_pb2.py +2 -0
  51. wandb/proto/wandb_settings_pb2.py +2 -0
  52. wandb/proto/wandb_telemetry_pb2.py +2 -0
  53. wandb/sdk/artifacts/artifact.py +76 -23
  54. wandb/sdk/artifacts/artifact_manifest.py +1 -1
  55. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
  56. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
  57. wandb/sdk/artifacts/artifact_saver.py +1 -10
  58. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
  59. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
  60. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
  61. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
  62. wandb/sdk/artifacts/storage_policy.py +1 -12
  63. wandb/sdk/data_types/_dtypes.py +5 -2
  64. wandb/sdk/data_types/html.py +1 -1
  65. wandb/sdk/data_types/image.py +1 -1
  66. wandb/sdk/data_types/object_3d.py +1 -1
  67. wandb/sdk/data_types/video.py +4 -2
  68. wandb/sdk/interface/interface.py +13 -0
  69. wandb/sdk/interface/interface_shared.py +1 -1
  70. wandb/sdk/internal/file_pusher.py +2 -5
  71. wandb/sdk/internal/file_stream.py +6 -19
  72. wandb/sdk/internal/internal_api.py +160 -138
  73. wandb/sdk/internal/job_builder.py +207 -135
  74. wandb/sdk/internal/progress.py +0 -28
  75. wandb/sdk/internal/sender.py +105 -42
  76. wandb/sdk/internal/settings_static.py +8 -1
  77. wandb/sdk/internal/system/assets/gpu.py +2 -0
  78. wandb/sdk/internal/system/assets/trainium.py +3 -3
  79. wandb/sdk/internal/system/system_info.py +4 -2
  80. wandb/sdk/internal/update.py +1 -1
  81. wandb/sdk/launch/__init__.py +9 -1
  82. wandb/sdk/launch/_launch.py +4 -24
  83. wandb/sdk/launch/_launch_add.py +1 -3
  84. wandb/sdk/launch/_project_spec.py +184 -224
  85. wandb/sdk/launch/agent/agent.py +58 -18
  86. wandb/sdk/launch/agent/config.py +0 -3
  87. wandb/sdk/launch/builder/abstract.py +67 -0
  88. wandb/sdk/launch/builder/build.py +165 -576
  89. wandb/sdk/launch/builder/context_manager.py +235 -0
  90. wandb/sdk/launch/builder/docker_builder.py +7 -23
  91. wandb/sdk/launch/builder/kaniko_builder.py +10 -23
  92. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  93. wandb/sdk/launch/create_job.py +51 -45
  94. wandb/sdk/launch/environment/aws_environment.py +26 -1
  95. wandb/sdk/launch/inputs/files.py +148 -0
  96. wandb/sdk/launch/inputs/internal.py +224 -0
  97. wandb/sdk/launch/inputs/manage.py +95 -0
  98. wandb/sdk/launch/runner/abstract.py +2 -2
  99. wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
  100. wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
  101. wandb/sdk/launch/runner/local_container.py +2 -3
  102. wandb/sdk/launch/runner/local_process.py +8 -29
  103. wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
  104. wandb/sdk/launch/runner/vertex_runner.py +8 -7
  105. wandb/sdk/launch/sweeps/scheduler.py +2 -0
  106. wandb/sdk/launch/sweeps/utils.py +2 -2
  107. wandb/sdk/launch/utils.py +16 -138
  108. wandb/sdk/lib/_settings_toposort_generated.py +2 -5
  109. wandb/sdk/lib/apikey.py +4 -2
  110. wandb/sdk/lib/config_util.py +3 -3
  111. wandb/sdk/lib/proto_util.py +22 -1
  112. wandb/sdk/lib/redirect.py +1 -1
  113. wandb/sdk/service/service.py +2 -1
  114. wandb/sdk/service/streams.py +5 -5
  115. wandb/sdk/wandb_init.py +25 -59
  116. wandb/sdk/wandb_login.py +28 -25
  117. wandb/sdk/wandb_run.py +135 -70
  118. wandb/sdk/wandb_settings.py +33 -64
  119. wandb/sdk/wandb_watch.py +1 -1
  120. wandb/sklearn/plot/classifier.py +4 -6
  121. wandb/sync/sync.py +2 -2
  122. wandb/testing/relay.py +32 -17
  123. wandb/util.py +39 -37
  124. wandb/wandb_agent.py +3 -3
  125. wandb/wandb_controller.py +3 -2
  126. {wandb-0.17.0rc2.dist-info → wandb-0.17.2.dist-info}/METADATA +7 -9
  127. {wandb-0.17.0rc2.dist-info → wandb-0.17.2.dist-info}/RECORD +130 -152
  128. wandb/apis/reports/v1/_blocks.py +0 -1406
  129. wandb/apis/reports/v1/_helpers.py +0 -70
  130. wandb/apis/reports/v1/_panels.py +0 -1282
  131. wandb/apis/reports/v1/_templates.py +0 -478
  132. wandb/apis/reports/v1/blocks.py +0 -27
  133. wandb/apis/reports/v1/helpers.py +0 -2
  134. wandb/apis/reports/v1/mutations.py +0 -66
  135. wandb/apis/reports/v1/panels.py +0 -17
  136. wandb/apis/reports/v1/report.py +0 -268
  137. wandb/apis/reports/v1/runset.py +0 -144
  138. wandb/apis/reports/v1/templates.py +0 -7
  139. wandb/apis/reports/v1/util.py +0 -406
  140. wandb/apis/reports/v1/validators.py +0 -131
  141. wandb/apis/reports/v2/blocks.py +0 -25
  142. wandb/apis/reports/v2/expr_parsing.py +0 -257
  143. wandb/apis/reports/v2/gql.py +0 -68
  144. wandb/apis/reports/v2/interface.py +0 -1911
  145. wandb/apis/reports/v2/internal.py +0 -867
  146. wandb/apis/reports/v2/metrics.py +0 -6
  147. wandb/apis/reports/v2/panels.py +0 -15
  148. wandb/catboost/__init__.py +0 -9
  149. wandb/fastai/__init__.py +0 -9
  150. wandb/keras/__init__.py +0 -19
  151. wandb/lightgbm/__init__.py +0 -9
  152. wandb/plots/__init__.py +0 -6
  153. wandb/plots/explain_text.py +0 -36
  154. wandb/plots/heatmap.py +0 -81
  155. wandb/plots/named_entity.py +0 -43
  156. wandb/plots/part_of_speech.py +0 -50
  157. wandb/plots/plot_definitions.py +0 -768
  158. wandb/plots/precision_recall.py +0 -121
  159. wandb/plots/roc.py +0 -103
  160. wandb/sacred/__init__.py +0 -3
  161. wandb/xgboost/__init__.py +0 -9
  162. {wandb-0.17.0rc2.dist-info → wandb-0.17.2.dist-info}/WHEEL +0 -0
  163. {wandb-0.17.0rc2.dist-info → wandb-0.17.2.dist-info}/entry_points.txt +0 -0
  164. {wandb-0.17.0rc2.dist-info → wandb-0.17.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,9 @@
1
1
  """Public API: artifacts."""
2
2
 
3
3
  import json
4
- from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
4
+ import re
5
+ from copy import copy
6
+ from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence
5
7
 
6
8
  from wandb_gql import Client, gql
7
9
 
@@ -10,6 +12,7 @@ from wandb.apis import public
10
12
  from wandb.apis.normalize import normalize_exceptions
11
13
  from wandb.apis.paginator import Paginator
12
14
  from wandb.errors.term import termlog
15
+ from wandb.sdk.lib import deprecate
13
16
 
14
17
  if TYPE_CHECKING:
15
18
  from wandb.apis.public import RetryingClient, Run
@@ -65,16 +68,15 @@ class ArtifactTypes(Paginator):
65
68
  $entityName: String!,
66
69
  $projectName: String!,
67
70
  $cursor: String,
68
- ) {
69
- project(name: $projectName, entityName: $entityName) {
70
- artifactTypes(after: $cursor) {
71
+ ) {{
72
+ project(name: $projectName, entityName: $entityName) {{
73
+ artifactTypes(after: $cursor) {{
71
74
  ...ArtifactTypesFragment
72
- }
73
- }
74
- }
75
- %s
76
- """
77
- % ARTIFACTS_TYPES_FRAGMENT
75
+ }}
76
+ }}
77
+ }}
78
+ {}
79
+ """.format(ARTIFACTS_TYPES_FRAGMENT)
78
80
  )
79
81
 
80
82
  def __init__(
@@ -178,7 +180,7 @@ class ArtifactType:
178
180
  or response.get("project") is None
179
181
  or response["project"].get("artifactType") is None
180
182
  ):
181
- raise ValueError("Could not find artifact type %s" % self.type)
183
+ raise ValueError("Could not find artifact type {}".format(self.type))
182
184
  self._attrs = response["project"]["artifactType"]
183
185
  return self._attrs
184
186
 
@@ -230,31 +232,32 @@ class ArtifactCollections(Paginator):
230
232
  $projectName: String!,
231
233
  $artifactTypeName: String!
232
234
  $cursor: String,
233
- ) {
234
- project(name: $projectName, entityName: $entityName) {
235
- artifactType(name: $artifactTypeName) {
236
- artifactCollections: %s(after: $cursor) {
237
- pageInfo {
235
+ ) {{
236
+ project(name: $projectName, entityName: $entityName) {{
237
+ artifactType(name: $artifactTypeName) {{
238
+ artifactCollections: {}(after: $cursor) {{
239
+ pageInfo {{
238
240
  endCursor
239
241
  hasNextPage
240
- }
242
+ }}
241
243
  totalCount
242
- edges {
243
- node {
244
+ edges {{
245
+ node {{
244
246
  id
245
247
  name
246
248
  description
247
249
  createdAt
248
- }
250
+ }}
249
251
  cursor
250
- }
251
- }
252
- }
253
- }
254
- }
255
- """
256
- % artifact_collection_plural_edge_name(
257
- server_supports_artifact_collections_gql_edges(client)
252
+ }}
253
+ }}
254
+ }}
255
+ }}
256
+ }}
257
+ """.format(
258
+ artifact_collection_plural_edge_name(
259
+ server_supports_artifact_collections_gql_edges(client)
260
+ )
258
261
  )
259
262
  )
260
263
 
@@ -318,12 +321,16 @@ class ArtifactCollection:
318
321
  self.client = client
319
322
  self.entity = entity
320
323
  self.project = project
321
- self.name = name
322
- self.type = type
324
+ self._name = name
325
+ self._saved_name = name
326
+ self._type = type
327
+ self._saved_type = type
323
328
  self._attrs = attrs
324
- if self._attrs is None:
325
- self.load()
329
+ self.load()
326
330
  self._aliases = [a["node"]["alias"] for a in self._attrs["aliases"]["edges"]]
331
+ self._description = self._attrs["description"]
332
+ self._tags = [a["node"]["name"] for a in self._attrs["tags"]["edges"]]
333
+ self._saved_tags = copy(self._tags)
327
334
 
328
335
  @property
329
336
  def id(self):
@@ -336,8 +343,8 @@ class ArtifactCollection:
336
343
  self.client,
337
344
  self.entity,
338
345
  self.project,
339
- self.name,
340
- self.type,
346
+ self._saved_name,
347
+ self._saved_type,
341
348
  per_page=per_page,
342
349
  )
343
350
 
@@ -356,33 +363,45 @@ class ArtifactCollection:
356
363
  $artifactCollectionName: String!,
357
364
  $cursor: String,
358
365
  $perPage: Int = 1000
359
- ) {
360
- project(name: $projectName, entityName: $entityName) {
361
- artifactType(name: $artifactTypeName) {
362
- artifactCollection: %s(name: $artifactCollectionName) {
366
+ ) {{
367
+ project(name: $projectName, entityName: $entityName) {{
368
+ artifactType(name: $artifactTypeName) {{
369
+ artifactCollection: {}(name: $artifactCollectionName) {{
363
370
  id
364
371
  name
365
372
  description
366
373
  createdAt
367
- aliases(after: $cursor, first: $perPage){
368
- edges {
369
- node {
374
+ tags {{
375
+ edges {{
376
+ node {{
377
+ id
378
+ name
379
+ }}
380
+ }}
381
+ }}
382
+ aliases(after: $cursor, first: $perPage){{
383
+ edges {{
384
+ node {{
370
385
  alias
371
- }
386
+ }}
372
387
  cursor
373
- }
374
- pageInfo {
388
+ }}
389
+ pageInfo {{
375
390
  endCursor
376
391
  hasNextPage
377
- }
378
- }
379
- }
380
- }
381
- }
382
- }
383
- """
384
- % artifact_collection_edge_name(
385
- server_supports_artifact_collections_gql_edges(self.client)
392
+ }}
393
+ }}
394
+ }}
395
+ artifactSequence(name: $artifactCollectionName) {{
396
+ __typename
397
+ }}
398
+ }}
399
+ }}
400
+ }}
401
+ """.format(
402
+ artifact_collection_edge_name(
403
+ server_supports_artifact_collections_gql_edges(self.client)
404
+ )
386
405
  )
387
406
  )
388
407
  response = self.client.execute(
@@ -390,8 +409,8 @@ class ArtifactCollection:
390
409
  variable_values={
391
410
  "entityName": self.entity,
392
411
  "projectName": self.project,
393
- "artifactTypeName": self.type,
394
- "artifactCollectionName": self.name,
412
+ "artifactTypeName": self._saved_type,
413
+ "artifactCollectionName": self._saved_name,
395
414
  },
396
415
  )
397
416
  if (
@@ -400,19 +419,28 @@ class ArtifactCollection:
400
419
  or response["project"].get("artifactType") is None
401
420
  or response["project"]["artifactType"].get("artifactCollection") is None
402
421
  ):
403
- raise ValueError("Could not find artifact type %s" % self.type)
404
- self._attrs = response["project"]["artifactType"]["artifactCollection"]
422
+ raise ValueError("Could not find artifact type {}".format(self._saved_type))
423
+ sequence = response["project"]["artifactType"]["artifactSequence"]
424
+ self._is_sequence = (
425
+ sequence is not None and sequence["__typename"] == "ArtifactSequence"
426
+ )
427
+
428
+ if self._attrs is None:
429
+ self._attrs = response["project"]["artifactType"]["artifactCollection"]
405
430
  return self._attrs
406
431
 
407
432
  def change_type(self, new_type: str) -> None:
408
- """Change the type of the artifact collection.
433
+ """Deprecated, change type directly with `save` instead."""
434
+ deprecate.deprecate(
435
+ field_name=deprecate.Deprecated.artifact_collection__change_type,
436
+ warning_message="ArtifactCollection.change_type(type) is deprecated, use ArtifactCollection.save() instead.",
437
+ )
409
438
 
410
- Arguments:
411
- new_type: The new collection type to use, freeform string.
412
- """
413
439
  if not self.is_sequence():
414
440
  raise ValueError("Artifact collection needs to be a sequence")
415
- termlog(f"Changing artifact collection type of " f"{self.type} to {new_type}")
441
+ termlog(
442
+ f"Changing artifact collection type of {self._saved_type} to {new_type}"
443
+ )
416
444
  template = """
417
445
  mutation MoveArtifactCollection(
418
446
  $artifactSequenceID: ID!
@@ -439,34 +467,12 @@ class ArtifactCollection:
439
467
  }
440
468
  mutation = gql(template)
441
469
  self.client.execute(mutation, variable_values=variable_values)
442
- self.type = new_type
470
+ self._saved_type = new_type
471
+ self._type = new_type
443
472
 
444
- @normalize_exceptions
445
473
  def is_sequence(self) -> bool:
446
- """Return True if this is a sequence."""
447
- query = gql(
448
- """
449
- query FindSequence($entity: String!, $project: String!, $collection: String!, $type: String!) {
450
- project(name: $project, entityName: $entity) {
451
- artifactType(name: $type) {
452
- __typename
453
- artifactSequence(name: $collection) {
454
- __typename
455
- }
456
- }
457
- }
458
- }
459
- """
460
- )
461
- variables = {
462
- "entity": self.entity,
463
- "project": self.project,
464
- "collection": self.name,
465
- "type": self.type,
466
- }
467
- res = self.client.execute(query, variable_values=variables)
468
- sequence = res["project"]["artifactType"]["artifactSequence"]
469
- return sequence is not None and sequence["__typename"] == "ArtifactSequence"
474
+ """Return whether the artifact collection is a sequence."""
475
+ return self._is_sequence
470
476
 
471
477
  @normalize_exceptions
472
478
  def delete(self):
@@ -501,8 +507,238 @@ class ArtifactCollection:
501
507
  )
502
508
  self.client.execute(mutation, variable_values={"id": self.id})
503
509
 
510
+ @property
511
+ def description(self):
512
+ """A description of the artifact collection."""
513
+ return self._description
514
+
515
+ @description.setter
516
+ def description(self, description: Optional[str]) -> None:
517
+ self._description = description
518
+
519
+ @property
520
+ def tags(self):
521
+ """The tags associated with the artifact collection."""
522
+ return self._tags
523
+
524
+ @tags.setter
525
+ def tags(self, tags: List[str]) -> None:
526
+ if any(not re.match(r"^[-\w]+([ ]+[-\w]+)*$", tag) for tag in tags):
527
+ raise ValueError(
528
+ "Tags must only contain alphanumeric characters or underscores separated by spaces or hyphens"
529
+ )
530
+ self._tags = tags
531
+
532
+ @property
533
+ def name(self):
534
+ """The name of the artifact collection."""
535
+ return self._name
536
+
537
+ @name.setter
538
+ def name(self, name: List[str]) -> None:
539
+ self._name = name
540
+
541
+ @property
542
+ def type(self):
543
+ """The type of the artifact collection."""
544
+ return self._type
545
+
546
+ @type.setter
547
+ def type(self, type: List[str]) -> None:
548
+ if not self.is_sequence():
549
+ raise ValueError(
550
+ "Type can only be changed if the artifact collection is a sequence."
551
+ )
552
+ self._type = type
553
+
554
+ def _update_collection(self):
555
+ mutation = gql("""
556
+ mutation UpdateArtifactCollection(
557
+ $artifactSequenceID: ID!
558
+ $name: String
559
+ $description: String
560
+ ) {
561
+ updateArtifactSequence(
562
+ input: {
563
+ artifactSequenceID: $artifactSequenceID
564
+ name: $name
565
+ description: $description
566
+ }
567
+ ) {
568
+ artifactCollection {
569
+ id
570
+ name
571
+ description
572
+ }
573
+ }
574
+ }
575
+ """)
576
+
577
+ variable_values = {
578
+ "artifactSequenceID": self.id,
579
+ "name": self._name,
580
+ "description": self.description,
581
+ }
582
+ self.client.execute(mutation, variable_values=variable_values)
583
+ self._saved_name = self._name
584
+
585
+ def _update_collection_type(self):
586
+ type_mutation = gql("""
587
+ mutation MoveArtifactCollection(
588
+ $artifactSequenceID: ID!
589
+ $destinationArtifactTypeName: String!
590
+ ) {
591
+ moveArtifactSequence(
592
+ input: {
593
+ artifactSequenceID: $artifactSequenceID
594
+ destinationArtifactTypeName: $destinationArtifactTypeName
595
+ }
596
+ ) {
597
+ artifactCollection {
598
+ id
599
+ name
600
+ description
601
+ __typename
602
+ }
603
+ }
604
+ }
605
+ """)
606
+
607
+ variable_values = {
608
+ "artifactSequenceID": self.id,
609
+ "destinationArtifactTypeName": self._type,
610
+ }
611
+ self.client.execute(type_mutation, variable_values=variable_values)
612
+ self._saved_type = self._type
613
+
614
+ def _update_portfolio(self):
615
+ mutation = gql("""
616
+ mutation UpdateArtifactPortfolio(
617
+ $artifactPortfolioID: ID!
618
+ $name: String
619
+ $description: String
620
+ ) {
621
+ updateArtifactPortfolio(
622
+ input: {
623
+ artifactPortfolioID: $artifactPortfolioID
624
+ name: $name
625
+ description: $description
626
+ }
627
+ ) {
628
+ artifactCollection {
629
+ id
630
+ name
631
+ description
632
+ }
633
+ }
634
+ }
635
+ """)
636
+ variable_values = {
637
+ "artifactPortfolioID": self.id,
638
+ "name": self._name,
639
+ "description": self.description,
640
+ }
641
+ self.client.execute(mutation, variable_values=variable_values)
642
+ self._saved_name = self._name
643
+
644
+ def _add_tags(self, tags_to_add):
645
+ add_mutation = gql(
646
+ """
647
+ mutation CreateArtifactCollectionTagAssignments(
648
+ $entityName: String!
649
+ $projectName: String!
650
+ $artifactCollectionName: String!
651
+ $tags: [TagInput!]!
652
+ ) {
653
+ createArtifactCollectionTagAssignments(
654
+ input: {
655
+ entityName: $entityName
656
+ projectName: $projectName
657
+ artifactCollectionName: $artifactCollectionName
658
+ tags: $tags
659
+ }
660
+ ) {
661
+ tags {
662
+ id
663
+ name
664
+ tagCategoryName
665
+ }
666
+ }
667
+ }
668
+ """
669
+ )
670
+ self.client.execute(
671
+ add_mutation,
672
+ variable_values={
673
+ "entityName": self.entity,
674
+ "projectName": self.project,
675
+ "artifactCollectionName": self._saved_name,
676
+ "tags": [
677
+ {
678
+ "tagName": tag,
679
+ }
680
+ for tag in tags_to_add
681
+ ],
682
+ },
683
+ )
684
+
685
+ def _delete_tags(self, tags_to_delete):
686
+ delete_mutation = gql(
687
+ """
688
+ mutation DeleteArtifactCollectionTagAssignments(
689
+ $entityName: String!
690
+ $projectName: String!
691
+ $artifactCollectionName: String!
692
+ $tags: [TagInput!]!
693
+ ) {
694
+ deleteArtifactCollectionTagAssignments(
695
+ input: {
696
+ entityName: $entityName
697
+ projectName: $projectName
698
+ artifactCollectionName: $artifactCollectionName
699
+ tags: $tags
700
+ }
701
+ ) {
702
+ success
703
+ }
704
+ }
705
+ """
706
+ )
707
+ self.client.execute(
708
+ delete_mutation,
709
+ variable_values={
710
+ "entityName": self.entity,
711
+ "projectName": self.project,
712
+ "artifactCollectionName": self._saved_name,
713
+ "tags": [
714
+ {
715
+ "tagName": tag,
716
+ }
717
+ for tag in tags_to_delete
718
+ ],
719
+ },
720
+ )
721
+
722
+ def save(self) -> None:
723
+ """Persist any changes made to the artifact collection."""
724
+ if self.is_sequence():
725
+ self._update_collection()
726
+
727
+ if self._saved_type != self._type:
728
+ self._update_collection_type()
729
+ else:
730
+ self._update_portfolio()
731
+
732
+ tags_to_add = set(self._tags) - set(self._saved_tags)
733
+ tags_to_delete = set(self._saved_tags) - set(self._tags)
734
+ if len(tags_to_add) > 0:
735
+ self._add_tags(tags_to_add)
736
+ if len(tags_to_delete) > 0:
737
+ self._delete_tags(tags_to_delete)
738
+ self._saved_tags = copy(self._tags)
739
+
504
740
  def __repr__(self):
505
- return f"<ArtifactCollection {self.name} ({self.type})>"
741
+ return f"<ArtifactCollection {self._name} ({self._type})>"
506
742
 
507
743
 
508
744
  class Artifacts(Paginator):
@@ -742,18 +978,17 @@ class ArtifactFiles(Paginator):
742
978
  $fileNames: [String!],
743
979
  $fileCursor: String,
744
980
  $fileLimit: Int = 50
745
- ) {
746
- project(name: $projectName, entityName: $entityName) {
747
- artifactType(name: $artifactTypeName) {
748
- artifact(name: $artifactName) {
981
+ ) {{
982
+ project(name: $projectName, entityName: $entityName) {{
983
+ artifactType(name: $artifactTypeName) {{
984
+ artifact(name: $artifactName) {{
749
985
  ...ArtifactFilesFragment
750
- }
751
- }
752
- }
753
- }
754
- %s
755
- """
756
- % ARTIFACT_FILES_FRAGMENT
986
+ }}
987
+ }}
988
+ }}
989
+ }}
990
+ {}
991
+ """.format(ARTIFACT_FILES_FRAGMENT)
757
992
  )
758
993
 
759
994
  def __init__(
@@ -46,17 +46,16 @@ class Files(Paginator):
46
46
  QUERY = gql(
47
47
  """
48
48
  query RunFiles($project: String!, $entity: String!, $name: String!, $fileCursor: String,
49
- $fileLimit: Int = 50, $fileNames: [String] = [], $upload: Boolean = false) {
50
- project(name: $project, entityName: $entity) {
51
- run(name: $name) {
49
+ $fileLimit: Int = 50, $fileNames: [String] = [], $upload: Boolean = false) {{
50
+ project(name: $project, entityName: $entity) {{
51
+ run(name: $name) {{
52
52
  fileCount
53
53
  ...RunFilesFragment
54
- }
55
- }
56
- }
57
- %s
58
- """
59
- % FILE_FRAGMENT
54
+ }}
55
+ }}
56
+ }}
57
+ {}
58
+ """.format(FILE_FRAGMENT)
60
59
  )
61
60
 
62
61
  def __init__(self, client, run, names=None, per_page=50, upload=False):
@@ -153,6 +152,7 @@ class File(Attrs):
153
152
  root (str): Local directory to save the file. Defaults to ".".
154
153
  exist_ok (boolean): If `True`, will not raise ValueError if file already
155
154
  exists and will not re-download unless replace=True. Defaults to `False`.
155
+ api (Api, optional): If given, the `Api` instance used to download the file.
156
156
 
157
157
  Raises:
158
158
  `ValueError` if file already exists, replace=False and exist_ok=False.
wandb/apis/public/jobs.py CHANGED
@@ -63,6 +63,8 @@ class Job:
63
63
  # only use notebook job if entrypoint not set and notebook is set
64
64
  self._notebook_job = source_info.get("notebook", False)
65
65
  self._entrypoint = source_info.get("entrypoint")
66
+ self._dockerfile = source_info.get("dockerfile")
67
+ self._build_context = source_info.get("build_context")
66
68
  self._args = source_info.get("args")
67
69
  self._partial = self._job_info.get("_partial", False)
68
70
  self._requirements_file = os.path.join(self._fpath, "requirements.frozen.txt")
@@ -106,7 +108,7 @@ class Job:
106
108
  )
107
109
  new_entrypoint = self._entrypoint
108
110
  new_entrypoint[-1] = new_fname
109
- launch_project.set_entry_point(new_entrypoint)
111
+ launch_project.set_job_entry_point(new_entrypoint)
110
112
 
111
113
  def _configure_launch_project_repo(self, launch_project):
112
114
  git_info = self._job_info.get("source", {}).get("git", {})
@@ -123,7 +125,12 @@ class Job:
123
125
  if self._notebook_job:
124
126
  self._configure_launch_project_notebook(launch_project)
125
127
  else:
126
- launch_project.set_entry_point(self._entrypoint)
128
+ launch_project.set_job_entry_point(self._entrypoint)
129
+
130
+ if self._dockerfile:
131
+ launch_project.set_job_dockerfile(self._dockerfile)
132
+ if self._build_context:
133
+ launch_project.set_job_build_context(self._build_context)
127
134
 
128
135
  def _configure_launch_project_artifact(self, launch_project):
129
136
  artifact_string = self._job_info.get("source", {}).get("artifact")
@@ -139,7 +146,12 @@ class Job:
139
146
  if self._notebook_job:
140
147
  self._configure_launch_project_notebook(launch_project)
141
148
  else:
142
- launch_project.set_entry_point(self._entrypoint)
149
+ launch_project.set_job_entry_point(self._entrypoint)
150
+
151
+ if self._dockerfile:
152
+ launch_project.set_job_dockerfile(self._dockerfile)
153
+ if self._build_context:
154
+ launch_project.set_job_build_context(self._build_context)
143
155
 
144
156
  def _configure_launch_project_container(self, launch_project):
145
157
  launch_project.docker_image = self._job_info.get("source", {}).get("image")
@@ -148,7 +160,7 @@ class Job:
148
160
  "Job had malformed source dictionary without an image key"
149
161
  )
150
162
  if self._entrypoint:
151
- launch_project.set_entry_point(self._entrypoint)
163
+ launch_project.set_job_entry_point(self._entrypoint)
152
164
 
153
165
  def set_entrypoint(self, entrypoint: List[str]):
154
166
  self._entrypoint = entrypoint