diff --git a/src/providers/blenderbot.py b/src/providers/blenderbot.py index f55b8df..23da087 100644 --- a/src/providers/blenderbot.py +++ b/src/providers/blenderbot.py @@ -1,7 +1,7 @@ -from .hfbasechat import BaseHFChatProvider +from .hfbasechat import BaseHFChatProvider, ProviderType class BlenderBotProvider(BaseHFChatProvider): name = "BlenderBot" description = "An open domain chatbot" provider = "facebook/blenderbot-400M-distill" - + provider_type = ProviderType.TEXT \ No newline at end of file diff --git a/src/providers/dialogpt.py b/src/providers/dialogpt.py index ad13d27..19fe982 100644 --- a/src/providers/dialogpt.py +++ b/src/providers/dialogpt.py @@ -1,6 +1,7 @@ -from .hfbasechat import BaseHFChatProvider +from .hfbasechat import BaseHFChatProvider, ProviderType class DialoGPTProvider(BaseHFChatProvider): name = "DialoGPT" description = "A State-of-the-Art Large-scale Pretrained Response generation model" provider = "microsoft/DialoGPT-large" + provider_type = ProviderType.CHAT diff --git a/src/providers/googleflant5xxl.py b/src/providers/googleflant5xxl.py index a146b06..96f4933 100644 --- a/src/providers/googleflant5xxl.py +++ b/src/providers/googleflant5xxl.py @@ -1,7 +1,7 @@ -from .hfbasechat import BaseHFChatProvider +from .hfbasechat import BaseHFChatProvider, ProviderType class GoogleFlant5XXLProvider(BaseHFChatProvider): name = "Google Flan T5 XXL" description = "A better Text-To-Text Transfer Transformer (T5) model" provider = "google/flan-t5-xxl" - chat_mode = False + provider_type = ProviderType.TEXT diff --git a/src/providers/gpt2.py b/src/providers/gpt2.py index de0088f..039210e 100644 --- a/src/providers/gpt2.py +++ b/src/providers/gpt2.py @@ -1,7 +1,7 @@ -from .hfbasechat import BaseHFChatProvider +from .hfbasechat import BaseHFChatProvider, ProviderType class GPT2Provider(BaseHFChatProvider): name = "GPT 2" description = "GPT-2 is a transformers model pretrained on a very large corpus of English data in a self-supervised fashion" provider = "gpt2" - chat_mode = False + provider_type = ProviderType.TEXT diff --git a/src/providers/hfbasechat.py b/src/providers/hfbasechat.py index d7cf147..160ea49 100644 --- a/src/providers/hfbasechat.py +++ b/src/providers/hfbasechat.py @@ -1,4 +1,4 @@ -from .base import BaseProvider +from .base import BaseProvider, ProviderType import requests @@ -23,7 +23,7 @@ class BaseHFChatProvider(BaseProvider): return response.json() - if self.chat_mode: + if self.provider_type == ProviderType.CHAT: output = query({ "inputs": { "past_user_inputs": [i['content'] for i in chat if i['role'] == self.app.user_name],