Add provider type to chat providers
This commit is contained in:
parent
0123a80883
commit
02b02edbbb
|
@ -1,7 +1,7 @@
|
||||||
from .hfbasechat import BaseHFChatProvider
|
from .hfbasechat import BaseHFChatProvider, ProviderType
|
||||||
|
|
||||||
class BlenderBotProvider(BaseHFChatProvider):
|
class BlenderBotProvider(BaseHFChatProvider):
|
||||||
name = "BlenderBot"
|
name = "BlenderBot"
|
||||||
description = "An open domain chatbot"
|
description = "An open domain chatbot"
|
||||||
provider = "facebook/blenderbot-400M-distill"
|
provider = "facebook/blenderbot-400M-distill"
|
||||||
|
provider_type = ProviderType.TEXT
|
|
@ -1,6 +1,7 @@
|
||||||
from .hfbasechat import BaseHFChatProvider
|
from .hfbasechat import BaseHFChatProvider, ProviderType
|
||||||
|
|
||||||
class DialoGPTProvider(BaseHFChatProvider):
|
class DialoGPTProvider(BaseHFChatProvider):
|
||||||
name = "DialoGPT"
|
name = "DialoGPT"
|
||||||
description = "A State-of-the-Art Large-scale Pretrained Response generation model"
|
description = "A State-of-the-Art Large-scale Pretrained Response generation model"
|
||||||
provider = "microsoft/DialoGPT-large"
|
provider = "microsoft/DialoGPT-large"
|
||||||
|
provider_type = ProviderType.CHAT
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from .hfbasechat import BaseHFChatProvider
|
from .hfbasechat import BaseHFChatProvider, ProviderType
|
||||||
|
|
||||||
class GoogleFlant5XXLProvider(BaseHFChatProvider):
|
class GoogleFlant5XXLProvider(BaseHFChatProvider):
|
||||||
name = "Google Flan T5 XXL"
|
name = "Google Flan T5 XXL"
|
||||||
description = "A better Text-To-Text Transfer Transformer (T5) model"
|
description = "A better Text-To-Text Transfer Transformer (T5) model"
|
||||||
provider = "google/flan-t5-xxl"
|
provider = "google/flan-t5-xxl"
|
||||||
chat_mode = False
|
provider_type = ProviderType.TEXT
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from .hfbasechat import BaseHFChatProvider
|
from .hfbasechat import BaseHFChatProvider, ProviderType
|
||||||
|
|
||||||
class GPT2Provider(BaseHFChatProvider):
|
class GPT2Provider(BaseHFChatProvider):
|
||||||
name = "GPT 2"
|
name = "GPT 2"
|
||||||
description = "GPT-2 is a transformers model pretrained on a very large corpus of English data in a self-supervised fashion"
|
description = "GPT-2 is a transformers model pretrained on a very large corpus of English data in a self-supervised fashion"
|
||||||
provider = "gpt2"
|
provider = "gpt2"
|
||||||
chat_mode = False
|
provider_type = ProviderType.TEXT
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from .base import BaseProvider
|
from .base import BaseProvider, ProviderType
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ class BaseHFChatProvider(BaseProvider):
|
||||||
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
if self.chat_mode:
|
if self.provider_type == ProviderType.CHAT:
|
||||||
output = query({
|
output = query({
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"past_user_inputs": [i['content'] for i in chat if i['role'] == self.app.user_name],
|
"past_user_inputs": [i['content'] for i in chat if i['role'] == self.app.user_name],
|
||||||
|
|
Loading…
Reference in New Issue