Add provider type to chat providers

This commit is contained in:
0xmrtt 2024-02-25 12:22:27 +01:00
parent 0123a80883
commit 02b02edbbb
5 changed files with 10 additions and 9 deletions

View File

@ -1,7 +1,7 @@
from .hfbasechat import BaseHFChatProvider
from .hfbasechat import BaseHFChatProvider, ProviderType
class BlenderBotProvider(BaseHFChatProvider):
name = "BlenderBot"
description = "An open domain chatbot"
provider = "facebook/blenderbot-400M-distill"
provider_type = ProviderType.TEXT

View File

@ -1,6 +1,7 @@
from .hfbasechat import BaseHFChatProvider
from .hfbasechat import BaseHFChatProvider, ProviderType
class DialoGPTProvider(BaseHFChatProvider):
name = "DialoGPT"
description = "A State-of-the-Art Large-scale Pretrained Response generation model"
provider = "microsoft/DialoGPT-large"
provider_type = ProviderType.CHAT

View File

@ -1,7 +1,7 @@
from .hfbasechat import BaseHFChatProvider
from .hfbasechat import BaseHFChatProvider, ProviderType
class GoogleFlant5XXLProvider(BaseHFChatProvider):
name = "Google Flan T5 XXL"
description = "A better Text-To-Text Transfer Transformer (T5) model"
provider = "google/flan-t5-xxl"
chat_mode = False
provider_type = ProviderType.TEXT

View File

@ -1,7 +1,7 @@
from .hfbasechat import BaseHFChatProvider
from .hfbasechat import BaseHFChatProvider, ProviderType
class GPT2Provider(BaseHFChatProvider):
name = "GPT 2"
description = "GPT-2 is a transformers model pretrained on a very large corpus of English data in a self-supervised fashion"
provider = "gpt2"
chat_mode = False
provider_type = ProviderType.TEXT

View File

@ -1,4 +1,4 @@
from .base import BaseProvider
from .base import BaseProvider, ProviderType
import requests
@ -23,7 +23,7 @@ class BaseHFChatProvider(BaseProvider):
return response.json()
if self.chat_mode:
if self.provider_type == ProviderType.CHAT:
output = query({
"inputs": {
"past_user_inputs": [i['content'] for i in chat if i['role'] == self.app.user_name],