langfun 0.1.2.dev202501010804__py3-none-any.whl → 0.1.2.dev202501030804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -105,8 +105,8 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
   # metrics as needed.
   experiment.run(root_dir, '20241031_1')
 
-  # Refresh the previous run located in 'run_20241031_1'.
-  experiment.run(root_dir, '20241031_1', refresh=True)
+  # Reprocess the previous run located in 'run_20241031_1'.
+  experiment.run(root_dir, '20241031_1', reprocess=True)
   ```
 
   # Experiment Registration and Lookup
@@ -380,7 +380,8 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
       filter: Callable[['Experiment'], bool] | None = None,  # pylint: disable=redefined-builtin
       example_ids: list[int] | None = None,
       raise_if_has_error: bool = False,
-      refresh: bool = False,
+      reprocess: bool | list[int] = False,
+      regenerate_example_html: bool | list[int] = False,
       process_timeout: int | None = None,
       use_cache: Literal['global', 'per_dataset', 'no'] = 'per_dataset',
       note: str | None = None,
@@ -391,22 +392,25 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
     """Runs the experiment.
 
     Examples:
-      # Start a new run.
-      experiment.run('new')
+      # Start a new run under root_dir.
+      experiment.run(root_dir, 'new')
 
       # Continue the latest experiment run.
-      experiment.run('latest')
+      experiment.run(root_dir, 'latest')
 
       # Continue the latest experiment run or start a new run if it does not
       # exist.
-      experiment.run()
+      experiment.run(root_dir)
 
-      # Start a new run and warm start from a previous run under sub-dir
-      # 'run_20241031_1'.
-      experiment.run('new', warm_start_from='20241031_1')
+      # Start a new run and warm start from another run's directory
+      # '/path/to/another/run_20241031_1/'.
+      experiment.run(
+          root_dir, 'new',
+          warm_start_from='/path/to/another/run_20241031_1/'
+      )
 
-      # Refresh previous run under sub-dir 'run_20241031_1'.
-      experiment.run('20241031_1', refresh=True)
+      # Reprocess previous run under sub-dir 'run_20241031_1'.
+      experiment.run(root_dir, '20241031_1', reprocess=True)
 
     Args:
       root_dir: The root of the output directory of the experiment.
@@ -426,8 +430,16 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
       example_ids: The example IDs to run. If None, it will run all examples.
       raise_if_has_error: If True, it will raise an error if any example fails.
         Otherwise, it will continue and report the error in the output.
-      refresh: Whether to refresh the experiment. If True, it will delete the
-        data under the current experiment run directory and start a new run.
+      reprocess: A boolean or a list of example IDs. If a boolean, it
+        indicates whether all the examples to be evaluated will be
+        reprocessed, meaning that existing checkpoints will be ignored. If a
+        list of example IDs, only the specified examples will be
+        reprocessed.
+      regenerate_example_html: A boolean or a list of example IDs. If a
+        boolean, it indicates whether all the examples to be evaluated will
+        have their HTML files regenerated. If a list of example IDs, only
+        the specified examples will have their HTML files regenerated.
       process_timeout: The timeout in seconds for each process. If None, it
         will use the default timeout for the runner.
      use_cache: Whether to use LLM cache for the experiment.
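To make the new knobs concrete, here is a hypothetical invocation sketched from the signature above; `my_experiment` and `root_dir` are placeholders, not part of the diff:

```python
# A minimal sketch based on the new run() signature; `my_experiment` and
# `root_dir` are hypothetical stand-ins for a defined experiment and an
# output directory.
run = my_experiment.run(
    root_dir,                      # e.g. '/tmp/my_eval'
    '20241031_1',                  # continue the run under 'run_20241031_1'
    reprocess=[1, 3],              # re-run only examples 1 and 3 from scratch
    regenerate_example_html=True,  # rebuild HTML for all evaluated examples
)
```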
@@ -454,7 +466,8 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
         filter=filter,
         example_ids=example_ids,
         raise_if_has_error=raise_if_has_error,
-        refresh=refresh,
+        reprocess=reprocess,
+        regenerate_example_html=regenerate_example_html,
         use_cache=use_cache,
         process_timeout=process_timeout,
         note=note,
@@ -815,11 +828,21 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
       'The user tags for the current run.'
   ] = []
 
-  refresh: Annotated[
-      bool,
+  reprocess: Annotated[
+      bool | list[int],
       (
-          'If True, it will delete the data under the current '
-          'run directory and start a new run.'
+          'If True, it will reprocess all examples under the current '
+          'run directory. If a list of integers, examples of the given IDs '
+          'will be reprocessed.'
+      )
+  ] = False
+
+  regenerate_example_html: Annotated[
+      bool | list[int],
+      (
+          'If True, it will regenerate the HTML files for previously processed '
+          'examples. If a list of integers, the HTML files for the examples of '
+          'the given IDs will be regenerated.'
       )
   ] = False
 
@@ -873,6 +896,42 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
     """Returns the output path for the experiment."""
     return os.path.join(self.output_dir(experiment), relative_path)
 
+  def examples_to_evaluate(self, experiment: Experiment) -> set[int]:
+    """Returns the example IDs to evaluate."""
+    if not experiment.is_leaf:
+      return set()
+    return set(
+        self.example_ids if self.example_ids else
+        range(1, experiment.num_examples + 1)
+    )
+
+  def examples_to_reprocess(self, experiment: Experiment) -> set[int]:
+    """Returns the example IDs to reprocess per request."""
+    if not self.reprocess:
+      return set()
+    reprocess_ids = self.examples_to_evaluate(experiment)
+    if isinstance(self.reprocess, list):
+      reprocess_ids &= set(self.reprocess)
+    return reprocess_ids
+
+  def examples_to_load(self, experiment: Experiment) -> set[int]:
+    """Returns the example IDs to load from checkpoint files."""
+    load_ids = self.examples_to_evaluate(experiment)
+    if isinstance(self.regenerate_example_html, list):
+      load_ids |= set(self.regenerate_example_html)
+    load_ids -= self.examples_to_reprocess(experiment)
+    return load_ids
+
+  def examples_to_load_metadata(self, experiment: Experiment) -> set[int]:
+    """Returns the example IDs to load metadata for."""
+    load_metadata_ids = set()
+    if isinstance(self.regenerate_example_html, list):
+      load_metadata_ids = set(self.regenerate_example_html)
+    elif self.regenerate_example_html:
+      load_metadata_ids = self.examples_to_evaluate(experiment)
+    load_metadata_ids -= self.examples_to_reprocess(experiment)
+    return load_metadata_ids
+
 
 class Runner(pg.Object):
   """Interface for experiment runner."""
@@ -31,10 +31,10 @@ Runner = experiment_lib.Runner
 
 
 @pg.functor()
-def sample_inputs():
+def sample_inputs(num_examples: int = 1):
   return [
       pg.Dict(x=1)
-  ]
+  ] * num_examples
 
 
 class MyEvaluation(Evaluation):
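An aside on the test-helper change: `@pg.functor()` makes `sample_inputs` return a deferred functor rather than a list, so the repetition count can be bound at declaration time. A small behavior sketch, assuming PyGlove's standard functor semantics:

```python
# Behavior sketch for the updated helper above, assuming standard
# pg.functor semantics: sample_inputs(10) binds the argument and returns a
# functor object; invoking that object evaluates the body.
inputs = sample_inputs(10)  # a pg.Functor, not yet a list
examples = inputs()         # evaluates to [pg.Dict(x=1)] * 10
assert len(examples) == 10
```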
@@ -208,7 +208,7 @@ class RunIdTest(unittest.TestCase):
 
 class RunTest(unittest.TestCase):
 
-  def test_basic(self):
+  def test_input_output_paths(self):
     run = Run(
         '/root',
         RunId.from_id('20241102_0'),
@@ -270,6 +270,107 @@ class RunTest(unittest.TestCase):
         )
     )
 
+  def test_examples_start_from_scratch(self):
+    run = Run(
+        '/root',
+        RunId.from_id('20241102_0'),
+        pg.Ref(Suite([
+            MyEvaluation(replica_id=0, inputs=sample_inputs(10)),
+        ])),
+    )
+    root = run.experiment
+    self.assertEqual(run.examples_to_evaluate(root), set())
+    self.assertEqual(run.examples_to_reprocess(root), set())
+    self.assertEqual(run.examples_to_load(root), set())
+    self.assertEqual(run.examples_to_load_metadata(root), set())
+
+    exp = root.leaf_nodes[0]
+    self.assertEqual(run.examples_to_evaluate(exp), set(range(1, 11)))
+    self.assertEqual(run.examples_to_reprocess(exp), set())
+    self.assertEqual(run.examples_to_load(exp), set(range(1, 11)))
+    self.assertEqual(run.examples_to_load_metadata(exp), set())
+
+  def test_examples_with_example_ids(self):
+    run = Run(
+        '/root',
+        RunId.from_id('20241102_0'),
+        pg.Ref(Suite([
+            MyEvaluation(replica_id=0, inputs=sample_inputs(10)),
+        ])),
+        example_ids=[1, 3, 5]
+    )
+    exp = run.experiment.leaf_nodes[0]
+    self.assertEqual(run.examples_to_evaluate(exp), set([1, 3, 5]))
+    self.assertEqual(run.examples_to_reprocess(exp), set())
+    self.assertEqual(run.examples_to_load(exp), set([1, 3, 5]))
+    self.assertEqual(run.examples_to_load_metadata(exp), set())
+
+  def test_examples_with_reprocess_all(self):
+    run = Run(
+        '/root',
+        RunId.from_id('20241102_0'),
+        pg.Ref(Suite([
+            MyEvaluation(replica_id=0, inputs=sample_inputs(10)),
+        ])),
+        example_ids=[1, 3, 5],
+        reprocess=True
+    )
+    exp = run.experiment.leaf_nodes[0]
+    self.assertEqual(run.examples_to_evaluate(exp), set([1, 3, 5]))
+    self.assertEqual(run.examples_to_reprocess(exp), set([1, 3, 5]))
+    self.assertEqual(run.examples_to_load(exp), set())
+    self.assertEqual(run.examples_to_load_metadata(exp), set())
+
+  def test_examples_with_reprocess_some(self):
+    run = Run(
+        '/root',
+        RunId.from_id('20241102_0'),
+        pg.Ref(Suite([
+            MyEvaluation(replica_id=0, inputs=sample_inputs(10)),
+        ])),
+        example_ids=[1, 3, 5],
+        reprocess=[1],
+    )
+    exp = run.experiment.leaf_nodes[0]
+    self.assertEqual(run.examples_to_evaluate(exp), set([1, 3, 5]))
+    self.assertEqual(run.examples_to_reprocess(exp), set([1]))
+    self.assertEqual(run.examples_to_load(exp), set([3, 5]))
+    self.assertEqual(run.examples_to_load_metadata(exp), set())
+
+  def test_examples_with_regenerate_example_html_all(self):
+    run = Run(
+        '/root',
+        RunId.from_id('20241102_0'),
+        pg.Ref(Suite([
+            MyEvaluation(replica_id=0, inputs=sample_inputs(10)),
+        ])),
+        example_ids=[1, 3, 5],
+        reprocess=[1],
+        regenerate_example_html=True,
+    )
+    exp = run.experiment.leaf_nodes[0]
+    self.assertEqual(run.examples_to_evaluate(exp), set([1, 3, 5]))
+    self.assertEqual(run.examples_to_reprocess(exp), set([1]))
+    self.assertEqual(run.examples_to_load(exp), set([3, 5]))
+    self.assertEqual(run.examples_to_load_metadata(exp), set([3, 5]))
+
+  def test_examples_with_regenerate_example_html_some(self):
+    run = Run(
+        '/root',
+        RunId.from_id('20241102_0'),
+        pg.Ref(Suite([
+            MyEvaluation(replica_id=0, inputs=sample_inputs(10)),
+        ])),
+        example_ids=[1, 3, 5],
+        reprocess=[1],
+        regenerate_example_html=[1, 2, 3],
+    )
+    exp = run.experiment.leaf_nodes[0]
+    self.assertEqual(run.examples_to_evaluate(exp), set([1, 3, 5]))
+    self.assertEqual(run.examples_to_reprocess(exp), set([1]))
+    self.assertEqual(run.examples_to_load(exp), set([2, 3, 5]))
+    self.assertEqual(run.examples_to_load_metadata(exp), set([2, 3]))
+
 
 class RunnerTest(unittest.TestCase):
 
@@ -172,11 +172,11 @@ class HtmlReporter(experiment_lib.Plugin):
         )
         html.save(index_html_path)
         experiment.info(
-            f'Generated HTML {index_html_path!r} in {t.elapse:.2f} seconds.',
+            f'Updated {index_html_path!r} in {t.elapse:.2f} seconds.',
         )
       except BaseException as e:  # pylint: disable=broad-except
         experiment.error(
-            f'Failed to save HTML {index_html_path!r}. '
+            f'Failed to generate {index_html_path!r}. '
             f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
         )
         raise e
@@ -194,26 +194,58 @@ class HtmlReporter(experiment_lib.Plugin):
   def _save_example_html(
       self, runner: Runner, experiment: Experiment, example: Example
   ) -> None:
-    """Saves the example."""
-    def _save():
+    """Saves the example in HTML format."""
+    def _generate():
       try:
-        html = example.to_html(
-            collapse_level=None,
-            enable_summary_tooltip=False,
-            extra_flags=dict(
-                # For properly rendering the next link.
-                num_examples=getattr(experiment, 'num_examples', None)
-            ),
-        )
-        html.save(
-            runner.current_run.output_path_for(
-                experiment, f'{example.id}.html'
-            )
+        with pg.timeit() as t:
+          html = example.to_html(
+              collapse_level=None,
+              enable_summary_tooltip=False,
+              extra_flags=dict(
+                  # For properly rendering the next link.
+                  num_examples=getattr(experiment, 'num_examples', None)
+              ),
+          )
+          html.save(
+              runner.current_run.output_path_for(
+                  experiment, f'{example.id}.html'
+              )
+          )
+        experiment.info(
+            f'\'{example.id}.html\' generated in {t.elapse:.2f} seconds. '
         )
       except BaseException as e:  # pylint: disable=broad-except
         experiment.error(
-            f'Failed to save HTML {example.id}.html. '
+            f'Failed to generate \'{example.id}.html\'. '
            f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
        )
        raise e
-    runner.background_run(_save)
+
+    def _copy():
+      src_file = runner.current_run.input_path_for(
+          experiment, f'{example.id}.html'
+      )
+      dest_file = runner.current_run.output_path_for(
+          experiment, f'{example.id}.html'
+      )
+      if src_file == dest_file:
+        return
+      try:
+        with pg.timeit() as t, pg.io.open(src_file, 'r') as src:
+          content = src.read()
+          with pg.io.open(dest_file, 'w') as dest:
+            dest.write(content)
+        experiment.info(
+            f'\'{example.id}.html\' copied in {t.elapse:.2f} seconds.'
+        )
+      except BaseException as e:  # pylint: disable=broad-except
+        experiment.error(
+            f'Failed to copy {src_file!r} to {dest_file!r}. Error: {e}.'
+        )
+        raise e
+
+    if example.newly_processed or runner.current_run.regenerate_example_html:
+      op = _generate
+    else:
+      op = _copy
+    runner.background_run(op)
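The dispatch at the end is the heart of this change: HTML is only re-rendered when the example was just processed or regeneration was explicitly requested; otherwise the file from the warm-start source is copied. A minimal standalone sketch of that decision (hypothetical scaffolding, mirroring the field names used above):

```python
# Standalone sketch of the generate-vs-copy dispatch above. The names
# `newly_processed` and `regenerate_example_html` mirror the fields in the
# diff; the string return values are hypothetical stand-ins for the ops.
def choose_html_op(newly_processed: bool,
                   regenerate_example_html: bool | list[int]) -> str:
  # Rendering is expensive; reuse the warm-start copy when the cached
  # result is still valid and no regeneration was requested.
  if newly_processed or regenerate_example_html:
    return 'generate'
  return 'copy'

assert choose_html_op(True, False) == 'generate'
assert choose_html_op(False, [2, 3]) == 'generate'  # non-empty list is truthy
assert choose_html_op(False, False) == 'copy'
```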
@@ -15,6 +15,7 @@ import os
 import tempfile
 import unittest
 
+from langfun.core.eval.v2 import checkpointing
 from langfun.core.eval.v2 import eval_test_helper
 from langfun.core.eval.v2 import reporting
 from langfun.core.eval.v2 import runners as runners_lib  # pylint: disable=unused-import
@@ -26,15 +27,131 @@ class ReportingTest(unittest.TestCase):
   def test_reporting(self):
     root_dir = os.path.join(tempfile.gettempdir(), 'test_reporting')
     experiment = eval_test_helper.test_experiment()
+    checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
+    reporter = reporting.HtmlReporter()
+    run = experiment.run(root_dir, 'new', plugins=[checkpointer, reporter])
+    self.assertTrue(
+        pg.io.path_exists(run.output_path_for(experiment, 'summary.html'))
+    )
+    for leaf in experiment.leaf_nodes:
+      self.assertTrue(
+          pg.io.path_exists(run.output_path_for(leaf, 'index.html'))
+      )
+      for i in range(leaf.num_examples):
+        self.assertTrue(
+            pg.io.path_exists(run.output_path_for(leaf, f'{i + 1}.html'))
+        )
+      found_generation_log = False
+      for log_entry in leaf._log_entries:
+        if 'generated in' in log_entry.message:
+          found_generation_log = True
+          break
+      self.assertTrue(found_generation_log)
+
+    # Test warm start.
+    root_dir = os.path.join(tempfile.gettempdir(), 'test_reporting2')
+    experiment = eval_test_helper.test_experiment()
+    run = experiment.run(
+        root_dir, 'new', plugins=[checkpointer, reporter],
+        warm_start_from=run.output_root
+    )
+    self.assertTrue(
+        pg.io.path_exists(run.output_path_for(experiment, 'summary.html'))
+    )
+    for leaf in experiment.leaf_nodes:
+      self.assertTrue(
+          pg.io.path_exists(run.output_path_for(leaf, 'index.html'))
+      )
+      for i in range(leaf.num_examples):
+        self.assertTrue(
+            pg.io.path_exists(run.output_path_for(leaf, f'{i + 1}.html'))
+        )
+      found_copy_log = False
+      for log_entry in leaf._log_entries:
+        if 'copied in' in log_entry.message:
+          found_copy_log = True
+          break
+      self.assertTrue(found_copy_log)
+
+  def test_index_html_generation_error(self):
+    root_dir = os.path.join(
+        tempfile.gettempdir(),
+        'test_reporting_with_index_html_generation_error'
+    )
+    experiment = (eval_test_helper
+                  .test_experiment_with_index_html_generation_error())
     reporter = reporting.HtmlReporter()
     run = experiment.run(root_dir, 'new', plugins=[reporter])
-    pg.io.path_exists(run.output_path_for(experiment, 'summary.html'))
+    self.assertFalse(
+        pg.io.path_exists(run.output_path_for(experiment, 'summary.html'))
+    )
+    for leaf in experiment.leaf_nodes:
+      self.assertFalse(
+          pg.io.path_exists(run.output_path_for(leaf, 'index.html'))
+      )
+    found_error_log = False
+    for log_entry in experiment._log_entries:
+      if log_entry.message.startswith('Failed to generate'):
+        found_error_log = True
+        break
+    self.assertTrue(found_error_log)
+
+  def test_example_html_generation_error(self):
+    root_dir = os.path.join(
+        tempfile.gettempdir(),
+        'test_reporting_with_example_html_generation_error'
+    )
+    experiment = (eval_test_helper
+                  .test_experiment_with_example_html_generation_error())
+    checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
+    reporter = reporting.HtmlReporter()
+    run = experiment.run(root_dir, 'new', plugins=[checkpointer, reporter])
+    self.assertTrue(
+        pg.io.path_exists(run.output_path_for(experiment, 'summary.html'))
+    )
+    for leaf in experiment.leaf_nodes:
+      self.assertTrue(
+          pg.io.path_exists(run.output_path_for(leaf, 'index.html'))
+      )
+      for i in range(leaf.num_examples):
+        self.assertFalse(
+            pg.io.path_exists(run.output_path_for(leaf, f'{i + 1}.html'))
+        )
+      found_error_log = False
+      for log_entry in experiment._log_entries:
+        if log_entry.message.startswith('Failed to generate'):
+          found_error_log = True
+          break
+      self.assertTrue(found_error_log)
+
+    # Test warm start.
+    root_dir = os.path.join(
+        tempfile.gettempdir(),
+        'test_reporting_with_example_html_generation_error2'
+    )
+    experiment = (eval_test_helper
+                  .test_experiment_with_example_html_generation_error())
+    run = experiment.run(
+        root_dir, 'new', plugins=[checkpointer, reporter],
+        warm_start_from=run.output_root
+    )
+    self.assertTrue(
+        pg.io.path_exists(run.output_path_for(experiment, 'summary.html'))
+    )
     for leaf in experiment.leaf_nodes:
       self.assertTrue(
           pg.io.path_exists(run.output_path_for(leaf, 'index.html'))
       )
       for i in range(leaf.num_examples):
-        pg.io.path_exists(run.output_path_for(leaf, f'{i + 1}.html'))
+        self.assertFalse(
+            pg.io.path_exists(run.output_path_for(leaf, f'{i + 1}.html'))
+        )
+      found_error_log = False
+      for log_entry in experiment._log_entries:
+        if log_entry.message.startswith('Failed to copy'):
+          found_error_log = True
+          break
+      self.assertTrue(found_error_log)
 
 
 if __name__ == '__main__':
@@ -123,6 +123,7 @@ class RunnerBase(Runner):
   def on_experiment_start(self, experiment: Experiment) -> None:
     """Called when an evaluation is started."""
     # Start the progress of the evaluation.
+    num_examples_to_evaluate = 0
    if experiment.is_leaf:
      assert isinstance(experiment, Evaluation)
      num_examples_to_evaluate = (
@@ -130,10 +131,6 @@ class RunnerBase(Runner):
          if self.current_run.example_ids else experiment.num_examples
      )
      experiment.progress.start(total=num_examples_to_evaluate)
-      experiment.info(
-          'Starting evaluation %s with %d examples to evaluate.'
-          % (experiment.id, num_examples_to_evaluate)
-      )
    else:
      experiment.progress.start(total=len(experiment.leaf_nodes))
 
@@ -141,6 +138,12 @@ class RunnerBase(Runner):
    for plugin in self._all_plugins(experiment):
      plugin.on_experiment_start(self, experiment)
 
+    if experiment.is_leaf:
+      experiment.info(
+          f'Starting evaluation {experiment.id!r} with '
+          f'{num_examples_to_evaluate} examples to evaluate.'
+      )
+
   def on_experiment_skipped(self, experiment: Experiment) -> None:
     """Called when an evaluation is skipped."""
     # Skip event will only be triggered for leaf evaluations.
@@ -45,6 +45,7 @@ from langfun.core.llms.google_genai import Palm2_IT
 # OpenAI models.
 from langfun.core.llms.openai import OpenAI
 
+from langfun.core.llms.openai import GptO1
 from langfun.core.llms.openai import GptO1Preview
 from langfun.core.llms.openai import GptO1Preview_20240912
 from langfun.core.llms.openai import GptO1Mini
@@ -106,6 +107,7 @@ from langfun.core.llms.anthropic import VertexAIAnthropic
 from langfun.core.llms.anthropic import VertexAIClaude3_5_Sonnet_20241022
 from langfun.core.llms.anthropic import VertexAIClaude3_5_Sonnet_20240620
 from langfun.core.llms.anthropic import VertexAIClaude3_5_Haiku_20241022
+from langfun.core.llms.anthropic import VertexAIClaude3_Opus_20240229
 
 from langfun.core.llms.groq import Groq
 from langfun.core.llms.groq import GroqLlama3_2_3B
@@ -67,6 +67,13 @@ SUPPORTED_MODELS_AND_SETTINGS = {
         cost_per_1k_input_tokens=0.001,
         cost_per_1k_output_tokens=0.005,
     ),
+    'claude-3-opus@20240229': pg.Dict(
+        max_tokens=4096,
+        rpm=4000,
+        tpm=400000,
+        cost_per_1k_input_tokens=0.015,
+        cost_per_1k_output_tokens=0.075,
+    ),
     # Anthropic hosted models.
     'claude-3-5-sonnet-20241022': pg.Dict(
         max_tokens=8192,
@@ -461,6 +468,11 @@ class VertexAIAnthropic(Anthropic):
     return request
 
 
+class VertexAIClaude3_Opus_20240229(VertexAIAnthropic):  # pylint: disable=invalid-name
+  """Anthropic's Claude 3 Opus model on VertexAI."""
+  model = 'claude-3-opus@20240229'
+
+
 class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic):  # pylint: disable=invalid-name
   """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
   model = 'claude-3-5-sonnet-v2@20241022'
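For context, a hypothetical usage of the newly exposed class; the project and location values are placeholders, and the constructor fields are assumed to come from the `VertexAIAnthropic` base class rather than being shown in this diff:

```python
from langfun.core.llms.anthropic import VertexAIClaude3_Opus_20240229

# Hypothetical usage sketch; 'my-gcp-project' and 'us-east5' are placeholder
# values, and the project/location fields are assumed from the
# VertexAIAnthropic base class (not shown in this diff).
lm = VertexAIClaude3_Opus_20240229(project='my-gcp-project',
                                   location='us-east5')
response = lm('Hello!')  # langfun LanguageModel instances are callable.
print(response.text)
```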
@@ -32,6 +32,13 @@ SUPPORTED_MODELS_AND_SETTINGS = {
     # o1 (preview) models.
     # Pricing in US dollars, from https://openai.com/api/pricing/
     # as of 2024-10-10.
+    'o1': pg.Dict(
+        in_service=True,
+        rpm=10000,
+        tpm=5000000,
+        cost_per_1k_input_tokens=0.015,
+        cost_per_1k_output_tokens=0.06,
+    ),
     'o1-preview': pg.Dict(
         in_service=True,
         rpm=10000,
@@ -255,25 +262,17 @@ SUPPORTED_MODELS_AND_SETTINGS = {
     ),
     # GPT-3.5 models
     'text-davinci-003': pg.Dict(
-        in_service=False,
-        rpm=_DEFAULT_RPM,
-        tpm=_DEFAULT_TPM
+        in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
     ),
     'text-davinci-002': pg.Dict(
-        in_service=False,
-        rpm=_DEFAULT_RPM,
-        tpm=_DEFAULT_TPM
+        in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
     ),
     'code-davinci-002': pg.Dict(
-        in_service=False,
-        rpm=_DEFAULT_RPM,
-        tpm=_DEFAULT_TPM
+        in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
     ),
     # GPT-3 instruction-tuned models (Deprecated)
     'text-curie-001': pg.Dict(
-        in_service=False,
-        rpm=_DEFAULT_RPM,
-        tpm=_DEFAULT_TPM
+        in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
     ),
     'text-babbage-001': pg.Dict(
         in_service=False,
  in_service=False,
@@ -290,32 +289,12 @@ SUPPORTED_MODELS_AND_SETTINGS = {
290
289
  rpm=_DEFAULT_RPM,
291
290
  tpm=_DEFAULT_TPM,
292
291
  ),
293
- 'curie': pg.Dict(
294
- in_service=False,
295
- rpm=_DEFAULT_RPM,
296
- tpm=_DEFAULT_TPM
297
- ),
298
- 'babbage': pg.Dict(
299
- in_service=False,
300
- rpm=_DEFAULT_RPM,
301
- tpm=_DEFAULT_TPM
302
- ),
303
- 'ada': pg.Dict(
304
- in_service=False,
305
- rpm=_DEFAULT_RPM,
306
- tpm=_DEFAULT_TPM
307
- ),
292
+ 'curie': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
293
+ 'babbage': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
294
+ 'ada': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
308
295
  # GPT-3 base models that are still in service.
309
- 'babbage-002': pg.Dict(
310
- in_service=True,
311
- rpm=_DEFAULT_RPM,
312
- tpm=_DEFAULT_TPM
313
- ),
314
- 'davinci-002': pg.Dict(
315
- in_service=True,
316
- rpm=_DEFAULT_RPM,
317
- tpm=_DEFAULT_TPM
318
- ),
296
+ 'babbage-002': pg.Dict(in_service=True, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
297
+ 'davinci-002': pg.Dict(in_service=True, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
319
298
  }
320
299
 
321
300
 
@@ -569,6 +548,13 @@ class OpenAI(rest.REST):
     )
 
 
+class GptO1(OpenAI):
+  """GPT-O1."""
+
+  model = 'o1'
+  multimodal = True
+
+
 class GptO1Preview(OpenAI):
   """GPT-O1."""
   model = 'o1-preview'
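A hypothetical usage of the new `GptO1` class, assuming langfun's standard LanguageModel calling convention and an API key available in the environment:

```python
from langfun.core.llms.openai import GptO1

# Hypothetical usage sketch; assumes an OpenAI API key is available via the
# OPENAI_API_KEY environment variable (or passed as api_key=...).
lm = GptO1()
response = lm('Say hello in one word.')  # models are callable on a prompt
print(response.text)
```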