langfun 0.1.2.dev202506240804__py3-none-any.whl → 0.1.2.dev202506260804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/language_model.py +56 -5
- langfun/core/language_model_test.py +22 -0
- {langfun-0.1.2.dev202506240804.dist-info → langfun-0.1.2.dev202506260804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202506240804.dist-info → langfun-0.1.2.dev202506260804.dist-info}/RECORD +7 -7
- {langfun-0.1.2.dev202506240804.dist-info → langfun-0.1.2.dev202506260804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202506240804.dist-info → langfun-0.1.2.dev202506260804.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202506240804.dist-info → langfun-0.1.2.dev202506260804.dist-info}/top_level.txt +0 -0
langfun/core/language_model.py
CHANGED
@@ -769,8 +769,25 @@ class LanguageModel(component.Component):
|
|
769
769
|
cls._MODEL_FACTORY[model_id_or_prefix] = factory
|
770
770
|
|
771
771
|
@classmethod
|
772
|
-
def get(cls,
|
773
|
-
"""Creates a language model instance from a model
|
772
|
+
def get(cls, model_str: str, *args, **kwargs) -> 'LanguageModel':
|
773
|
+
"""Creates a language model instance from a model str.
|
774
|
+
|
775
|
+
Args:
|
776
|
+
model_str: A string that identifies the model. It can be a model ID or a
|
777
|
+
model ID with kwargs.
|
778
|
+
For example, "gpt-o3?temperature=0.1&n=2" will create a GPT-o3 model
|
779
|
+
with temperature set to 0.1 and n set to 2.
|
780
|
+
*args: Additional arguments to pass to the model factory.
|
781
|
+
**kwargs: Additional keyword arguments to pass to the model factory.
|
782
|
+
kwargs provided here will take precedence over kwargs parsed from
|
783
|
+
model_str.
|
784
|
+
|
785
|
+
Returns:
|
786
|
+
A language model instance.
|
787
|
+
"""
|
788
|
+
model_id, model_kwargs = cls._parse_model_str(model_str)
|
789
|
+
model_kwargs.update(kwargs)
|
790
|
+
|
774
791
|
factory = cls._MODEL_FACTORY.get(model_id)
|
775
792
|
if factory is None:
|
776
793
|
factories = []
|
@@ -786,11 +803,45 @@ class LanguageModel(component.Component):
|
|
786
803
|
'Please specify a more specific model ID.'
|
787
804
|
)
|
788
805
|
factory = factories[0][1]
|
789
|
-
return factory(model_id, *args, **
|
806
|
+
return factory(model_id, *args, **model_kwargs)
|
790
807
|
|
791
808
|
@classmethod
def _parse_model_str(cls, model_str: str) -> tuple[str, dict[str, Any]]:
  """Parses a model string into a model ID and keyword arguments.

  The accepted form is ``<model_id>[?<key>=<value>[&<key>=<value>]...]``,
  for example ``'my-model?temperature=0.1&n=2'``. Keys and values are
  stripped of surrounding whitespace, and each value is converted to the
  first of the following that accepts it: ``int``, ``float`` (note that
  ``float`` also accepts ``'nan'`` / ``'inf'``), ``bool`` (``'true'`` /
  ``'false'``, case-insensitive); anything else is kept as a ``str``.

  Args:
    model_str: The model string to parse.

  Returns:
    A ``(model_id, kwargs)`` tuple. ``kwargs`` is empty when ``model_str``
    contains no ``'?'``.

  Raises:
    ValueError: If ``model_str`` contains more than one ``'?'``, or if any
      ``&``-separated item is not of the form ``key=value`` with a non-empty
      key.
  """

  def _convert(value: str) -> Any:
    # Try int before float so '2' and '-1' stay integers. `int` only accepts
    # decimal digits, unlike `str.isnumeric()` which also accepts e.g. '²'.
    for converter in (int, float):
      try:
        return converter(value)
      except ValueError:
        pass
    if value.lower() in ('true', 'false'):
      return value.lower() == 'true'
    return value

  model_id, sep, kwargs_str = model_str.partition('?')
  if not sep:
    return model_str, {}
  if '?' in kwargs_str:
    raise ValueError(f'Invalid model string: {model_str!r}.')

  kwargs = {}
  for item in kwargs_str.split('&'):
    # Split on the first '=' only, so values may themselves contain '='.
    key, eq, value = item.partition('=')
    key = key.strip()
    if not eq or not key:
      raise ValueError(f'Invalid kwargs in model string: {model_str!r}.')
    kwargs[key] = _convert(value.strip())
  return model_id, kwargs
|
836
|
+
|
837
|
+
@classmethod
|
838
|
+
def dir(cls, regex: str | None = None):
|
839
|
+
"""Returns a list of model IDs that match the given regex."""
|
840
|
+
if regex is None:
|
841
|
+
return sorted(list(LanguageModel._MODEL_FACTORY.keys()))
|
842
|
+
return sorted(
|
843
|
+
[k for k in LanguageModel._MODEL_FACTORY.keys() if re.match(regex, k)]
|
844
|
+
)
|
794
845
|
|
795
846
|
@pg.explicit_method_override
|
796
847
|
def __init__(self, *args, **kwargs) -> None:
|
@@ -196,6 +196,8 @@ class LanguageModelTest(unittest.TestCase):
|
|
196
196
|
self.assertEqual(lm.failures_before_attempt, 1)
|
197
197
|
self.assertEqual(lm.sampling_options.temperature, 0.2)
|
198
198
|
self.assertIn('MockModel', lm_lib.LanguageModel.dir())
|
199
|
+
self.assertIn('MockModel', lm_lib.LanguageModel.dir('Mock.*'))
|
200
|
+
self.assertEqual(lm_lib.LanguageModel.dir('NotMock.*'), [])
|
199
201
|
|
200
202
|
lm_lib.LanguageModel.register('mock://.*', mock_model)
|
201
203
|
lm_lib.LanguageModel.register('mock.*', mock_model)
|
@@ -206,6 +208,26 @@ class LanguageModelTest(unittest.TestCase):
|
|
206
208
|
with self.assertRaisesRegex(ValueError, 'Model not found'):
|
207
209
|
lm_lib.LanguageModel.get('non-existent://test2')
|
208
210
|
|
211
|
+
lm = lm_lib.LanguageModel.get('MockModel?temperature=0.1')
|
212
|
+
self.assertEqual(lm.sampling_options.temperature, 0.1)
|
213
|
+
|
214
|
+
lm = lm_lib.LanguageModel.get(
|
215
|
+
'MockModel?temperature=0.1&name=my_model&logprobs=true&n=2'
|
216
|
+
)
|
217
|
+
self.assertEqual(lm.sampling_options.temperature, 0.1)
|
218
|
+
self.assertEqual(lm.sampling_options.n, 2)
|
219
|
+
self.assertTrue(lm.sampling_options.logprobs)
|
220
|
+
self.assertEqual(lm.name, 'my_model')
|
221
|
+
|
222
|
+
lm = lm_lib.LanguageModel.get('MockModel?temperature=0.1', temperature=0.2)
|
223
|
+
self.assertEqual(lm.sampling_options.temperature, 0.2)
|
224
|
+
|
225
|
+
with self.assertRaisesRegex(ValueError, 'Invalid model string'):
|
226
|
+
lm_lib.LanguageModel.get('MockModel??')
|
227
|
+
|
228
|
+
with self.assertRaisesRegex(ValueError, 'Invalid kwargs in model string'):
|
229
|
+
lm_lib.LanguageModel.get('MockModel?temperature=0.1&')
|
230
|
+
|
209
231
|
def test_basics(self):
|
210
232
|
lm = MockModel(1, temperature=0.5, top_k=2, max_attempts=2)
|
211
233
|
self.assertEqual(
|
@@ -8,8 +8,8 @@ langfun/core/console.py,sha256=cLQEf84aDxItA9fStJV22xJch0TqFLNf9hLqwJ0RHmU,2652
|
|
8
8
|
langfun/core/console_test.py,sha256=pBOcuNMJdVELywvroptfcRtJMsegMm3wSlHAL2TdxVk,1679
|
9
9
|
langfun/core/langfunc.py,sha256=G50YgoVZ0y1GFw2ev41MlOqr6qa8YakbvNC0h_E0PiA,11140
|
10
10
|
langfun/core/langfunc_test.py,sha256=CDn-gJCa5EnjN7cotAVCfSCbuzddq2o-HzEt7kV8HbY,8882
|
11
|
-
langfun/core/language_model.py,sha256=
|
12
|
-
langfun/core/language_model_test.py,sha256=
|
11
|
+
langfun/core/language_model.py,sha256=fJeYDz_TD1feiUvSysXXeo2bV-cq5T34HWeYgsTICP4,49680
|
12
|
+
langfun/core/language_model_test.py,sha256=VyiHwrUtJGkoLyzsjhVqawijwtwoRqsYOvQD57n8Iv8,37413
|
13
13
|
langfun/core/logging.py,sha256=7IGAhp7mGokZxxqtL-XZvFLKaZ5k3F5_Xp2NUtR4GwE,9136
|
14
14
|
langfun/core/logging_test.py,sha256=vbVGOQxwMmVSiFfbt2897gUt-8nqDpV64jCAeUG_q5U,6924
|
15
15
|
langfun/core/memory.py,sha256=vyXVvfvSdLLJAzdIupnbn3k26OgclCx-OJ7gddS5e1Y,2070
|
@@ -156,8 +156,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
156
156
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
157
157
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
158
158
|
langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
|
159
|
-
langfun-0.1.2.
|
160
|
-
langfun-0.1.2.
|
161
|
-
langfun-0.1.2.
|
162
|
-
langfun-0.1.2.
|
163
|
-
langfun-0.1.2.
|
159
|
+
langfun-0.1.2.dev202506260804.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
160
|
+
langfun-0.1.2.dev202506260804.dist-info/METADATA,sha256=NQpc4kGauaQeKVVYTyt65znOYgwycKLe6K_HpBpPnok,8178
|
161
|
+
langfun-0.1.2.dev202506260804.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
162
|
+
langfun-0.1.2.dev202506260804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
163
|
+
langfun-0.1.2.dev202506260804.dist-info/RECORD,,
|
File without changes
|
{langfun-0.1.2.dev202506240804.dist-info → langfun-0.1.2.dev202506260804.dist-info}/licenses/LICENSE
RENAMED
File without changes
|
{langfun-0.1.2.dev202506240804.dist-info → langfun-0.1.2.dev202506260804.dist-info}/top_level.txt
RENAMED
File without changes
|