langfun 0.1.2.dev202506250804__py3-none-any.whl → 0.1.2.dev202506260804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -769,8 +769,25 @@ class LanguageModel(component.Component):
769
769
  cls._MODEL_FACTORY[model_id_or_prefix] = factory
770
770
 
771
771
  @classmethod
772
- def get(cls, model_id: str, *args, **kwargs):
773
- """Creates a language model instance from a model ID."""
772
+ def get(cls, model_str: str, *args, **kwargs) -> 'LanguageModel':
773
+ """Creates a language model instance from a model str.
774
+
775
+ Args:
776
+ model_str: A string that identifies the model. It can be a model ID or a
777
+ model ID with kwargs.
778
+ For example, "gpt-o3?temperature=0.1&n=2" will create a GPT-o3 model
779
+ with temperature set to 0.1 and n set to 2.
780
+ *args: Additional arguments to pass to the model factory.
781
+ **kwargs: Additional keyword arguments to pass to the model factory.
782
+ kwargs provided here will take precedence over kwargs parsed from
783
+ model_str.
784
+
785
+ Returns:
786
+ A language model instance.
787
+ """
788
+ model_id, model_kwargs = cls._parse_model_str(model_str)
789
+ model_kwargs.update(kwargs)
790
+
774
791
  factory = cls._MODEL_FACTORY.get(model_id)
775
792
  if factory is None:
776
793
  factories = []
@@ -786,11 +803,45 @@ class LanguageModel(component.Component):
786
803
  'Please specify a more specific model ID.'
787
804
  )
788
805
  factory = factories[0][1]
789
- return factory(model_id, *args, **kwargs)
806
+ return factory(model_id, *args, **model_kwargs)
790
807
 
791
808
  @classmethod
792
- def dir(cls):
793
- return sorted(list(LanguageModel._MODEL_FACTORY.keys()))
809
+ def _parse_model_str(cls, model_str: str) -> tuple[str, dict[str, Any]]:
810
+ """Parses a model string into model ID and kwargs."""
811
+ parts = model_str.split('?')
812
+ if len(parts) == 1:
813
+ return model_str, {}
814
+ elif len(parts) == 2:
815
+ model_id, kwargs_str = parts
816
+ kwargs = {}
817
+ for kv in kwargs_str.split('&'):
818
+ kv_parts = kv.split('=')
819
+ if len(kv_parts) != 2:
820
+ raise ValueError(f'Invalid kwargs in model string: {model_str!r}.')
821
+ k, v = kv_parts
822
+ if v.isnumeric():
823
+ v = int(v)
824
+ elif v.lower() in ('true', 'false'):
825
+ v = v.lower() == 'true'
826
+ else:
827
+ v = v.strip()
828
+ try:
829
+ v = float(v)
830
+ except ValueError:
831
+ pass
832
+ kwargs[k] = v
833
+ return model_id, kwargs
834
+ else:
835
+ raise ValueError(f'Invalid model string: {model_str!r}.')
836
+
837
+ @classmethod
838
+ def dir(cls, regex: str | None = None):
839
+ """Returns a list of model IDs that match the given regex."""
840
+ if regex is None:
841
+ return sorted(list(LanguageModel._MODEL_FACTORY.keys()))
842
+ return sorted(
843
+ [k for k in LanguageModel._MODEL_FACTORY.keys() if re.match(regex, k)]
844
+ )
794
845
 
795
846
  @pg.explicit_method_override
796
847
  def __init__(self, *args, **kwargs) -> None:
@@ -196,6 +196,8 @@ class LanguageModelTest(unittest.TestCase):
196
196
  self.assertEqual(lm.failures_before_attempt, 1)
197
197
  self.assertEqual(lm.sampling_options.temperature, 0.2)
198
198
  self.assertIn('MockModel', lm_lib.LanguageModel.dir())
199
+ self.assertIn('MockModel', lm_lib.LanguageModel.dir('Mock.*'))
200
+ self.assertEqual(lm_lib.LanguageModel.dir('NotMock.*'), [])
199
201
 
200
202
  lm_lib.LanguageModel.register('mock://.*', mock_model)
201
203
  lm_lib.LanguageModel.register('mock.*', mock_model)
@@ -206,6 +208,26 @@ class LanguageModelTest(unittest.TestCase):
206
208
  with self.assertRaisesRegex(ValueError, 'Model not found'):
207
209
  lm_lib.LanguageModel.get('non-existent://test2')
208
210
 
211
+ lm = lm_lib.LanguageModel.get('MockModel?temperature=0.1')
212
+ self.assertEqual(lm.sampling_options.temperature, 0.1)
213
+
214
+ lm = lm_lib.LanguageModel.get(
215
+ 'MockModel?temperature=0.1&name=my_model&logprobs=true&n=2'
216
+ )
217
+ self.assertEqual(lm.sampling_options.temperature, 0.1)
218
+ self.assertEqual(lm.sampling_options.n, 2)
219
+ self.assertTrue(lm.sampling_options.logprobs)
220
+ self.assertEqual(lm.name, 'my_model')
221
+
222
+ lm = lm_lib.LanguageModel.get('MockModel?temperature=0.1', temperature=0.2)
223
+ self.assertEqual(lm.sampling_options.temperature, 0.2)
224
+
225
+ with self.assertRaisesRegex(ValueError, 'Invalid model string'):
226
+ lm_lib.LanguageModel.get('MockModel??')
227
+
228
+ with self.assertRaisesRegex(ValueError, 'Invalid kwargs in model string'):
229
+ lm_lib.LanguageModel.get('MockModel?temperature=0.1&')
230
+
209
231
  def test_basics(self):
210
232
  lm = MockModel(1, temperature=0.5, top_k=2, max_attempts=2)
211
233
  self.assertEqual(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langfun
3
- Version: 0.1.2.dev202506250804
3
+ Version: 0.1.2.dev202506260804
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -8,8 +8,8 @@ langfun/core/console.py,sha256=cLQEf84aDxItA9fStJV22xJch0TqFLNf9hLqwJ0RHmU,2652
8
8
  langfun/core/console_test.py,sha256=pBOcuNMJdVELywvroptfcRtJMsegMm3wSlHAL2TdxVk,1679
9
9
  langfun/core/langfunc.py,sha256=G50YgoVZ0y1GFw2ev41MlOqr6qa8YakbvNC0h_E0PiA,11140
10
10
  langfun/core/langfunc_test.py,sha256=CDn-gJCa5EnjN7cotAVCfSCbuzddq2o-HzEt7kV8HbY,8882
11
- langfun/core/language_model.py,sha256=GCvbu749TviZyLQSNunO0rKeDAb7E_6G4rzqQJerN_E,47913
12
- langfun/core/language_model_test.py,sha256=iA5uo7rIj2jAtCYzMzhyNg1fWqE2Onn60bOO58q72C0,36454
11
+ langfun/core/language_model.py,sha256=fJeYDz_TD1feiUvSysXXeo2bV-cq5T34HWeYgsTICP4,49680
12
+ langfun/core/language_model_test.py,sha256=VyiHwrUtJGkoLyzsjhVqawijwtwoRqsYOvQD57n8Iv8,37413
13
13
  langfun/core/logging.py,sha256=7IGAhp7mGokZxxqtL-XZvFLKaZ5k3F5_Xp2NUtR4GwE,9136
14
14
  langfun/core/logging_test.py,sha256=vbVGOQxwMmVSiFfbt2897gUt-8nqDpV64jCAeUG_q5U,6924
15
15
  langfun/core/memory.py,sha256=vyXVvfvSdLLJAzdIupnbn3k26OgclCx-OJ7gddS5e1Y,2070
@@ -156,8 +156,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
156
156
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
157
157
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
158
158
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
159
- langfun-0.1.2.dev202506250804.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
160
- langfun-0.1.2.dev202506250804.dist-info/METADATA,sha256=cB6OfkCuab3YdlkzG7LsZXsRx72XNALqeqnAbm3G6oI,8178
161
- langfun-0.1.2.dev202506250804.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
162
- langfun-0.1.2.dev202506250804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
163
- langfun-0.1.2.dev202506250804.dist-info/RECORD,,
159
+ langfun-0.1.2.dev202506260804.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
160
+ langfun-0.1.2.dev202506260804.dist-info/METADATA,sha256=NQpc4kGauaQeKVVYTyt65znOYgwycKLe6K_HpBpPnok,8178
161
+ langfun-0.1.2.dev202506260804.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
162
+ langfun-0.1.2.dev202506260804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
163
+ langfun-0.1.2.dev202506260804.dist-info/RECORD,,