feat: add more providers
This commit is contained in:
parent
fb08eaa9fe
commit
b438eb1077
|
@ -46,8 +46,8 @@
|
|||
"sources": [
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/e6/02/a2cff6306177ae6bc73bc0665065de51dfb3b9db7373e122e2735faf0d97/tqdm-4.65.0-py3-none-any.whl",
|
||||
"sha256": "c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"
|
||||
"url": "https://files.pythonhosted.org/packages/00/e5/f12a80907d0884e6dff9c16d0c0114d81b8cd07dc3ae54c5e962cc83037e/tqdm-4.66.1-py3-none-any.whl",
|
||||
"sha256": "d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"
|
||||
}
|
||||
]
|
||||
},
|
||||
|
@ -106,6 +106,85 @@
|
|||
"sha256": "b4246fb7677d3b98f501a39d43396d3cafdc8eadb045f4a31be01863f655c610"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "python3-openai",
|
||||
"buildsystem": "simple",
|
||||
"build-commands": [
|
||||
"pip3 install --verbose --exists-action=i --no-index --find-links=\"file://${PWD}\" --prefix=${FLATPAK_DEST} \"openai\" --no-build-isolation"
|
||||
],
|
||||
"sources": [
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/d6/12/6fc7c7dcc84e263940e87cbafca17c1ef28f39dae6c0b10f51e4ccc764ee/aiohttp-3.8.5.tar.gz",
|
||||
"sha256": "b9552ec52cc147dbf1944ac7ac98af7602e51ea2dcd076ed194ca3c0d1c7d0bc"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl",
|
||||
"sha256": "f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl",
|
||||
"sha256": "7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/f0/eb/fcb708c7bf5056045e9e98f62b93bd7467eb718b0202e7698eb11d66416c/attrs-23.1.0-py3-none-any.whl",
|
||||
"sha256": "1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/4c/dd/2234eab22353ffc7d94e8d13177aaa050113286e93e7b40eae01fbf7c3d9/certifi-2023.7.22-py3-none-any.whl",
|
||||
"sha256": "92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/2a/53/cf0a48de1bdcf6ff6e1c9a023f5f523dfe303e4024f216feac64b6eb7f67/charset-normalizer-3.2.0.tar.gz",
|
||||
"sha256": "3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/8c/1f/49c96ccc87127682ba900b092863ef7c20302a2144b3185412a08480ca22/frozenlist-1.4.0.tar.gz",
|
||||
"sha256": "09163bdf0b2907454042edb19f887c6d33806adc71fbd54afc14908bfdc22251"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/fc/34/3030de6f1370931b9dbb4dad48f6ab1015ab1d32447850b9fc94e60097be/idna-3.4-py3-none-any.whl",
|
||||
"sha256": "90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/4a/15/bd620f7a6eb9aa5112c4ef93e7031bcd071e0611763d8e17706ef8ba65e0/multidict-6.0.4.tar.gz",
|
||||
"sha256": "3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/67/78/7588a047e458cb8075a4089d721d7af5e143ff85a2388d4a28c530be0494/openai-0.27.8-py3-none-any.whl",
|
||||
"sha256": "e0a7c2f7da26bdbe5354b03c6d4b82a2f34bd4458c7a17ae1a7092c3e397e03c"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/70/8e/0e2d847013cb52cd35b38c009bb167a1a26b2ce6cd6965bf26b47bc0bf44/requests-2.31.0-py3-none-any.whl",
|
||||
"sha256": "58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/00/e5/f12a80907d0884e6dff9c16d0c0114d81b8cd07dc3ae54c5e962cc83037e/tqdm-4.66.1-py3-none-any.whl",
|
||||
"sha256": "d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/9b/81/62fd61001fa4b9d0df6e31d47ff49cfa9de4af03adecf339c7bc30656b37/urllib3-2.0.4-py3-none-any.whl",
|
||||
"sha256": "de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"url": "https://files.pythonhosted.org/packages/5f/3f/04b3c5e57844fb9c034b09c5cb6d2b43de5d64a093c30529fd233e16cf09/yarl-1.9.2.tar.gz",
|
||||
"sha256": "04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
|
@ -7,4 +7,5 @@ tqdm
|
|||
charset-normalizer
|
||||
idna
|
||||
urllib3
|
||||
Babel
|
||||
Babel
|
||||
openai
|
|
@ -129,6 +129,8 @@ class BavarderApplication(Adw.Application):
|
|||
|
||||
def save(self):
|
||||
with open(self.data_path, "w", encoding="utf-8") as f:
|
||||
for name, d in self.data["providers"].items():
|
||||
print(d)
|
||||
self.data = json.dump(self.data, f)
|
||||
self.settings.set_boolean("local-mode", self.local_mode)
|
||||
self.settings.set_string("current-provider", self.current_provider)
|
||||
|
@ -173,7 +175,7 @@ class BavarderApplication(Adw.Application):
|
|||
self.providers = {}
|
||||
|
||||
for provider in PROVIDERS:
|
||||
p = provider(self, self.win, self.data["providers"])
|
||||
p = provider(self, self.win)
|
||||
|
||||
self.providers[p.slug] = p
|
||||
|
||||
|
|
|
@ -1,9 +1,23 @@
|
|||
from .blenderbot import BlenderBotProvider
|
||||
from .catgpt import CatGPTProvider
|
||||
from .dialogpt import DialoGPTProvider
|
||||
from .stablebeluga2 import StableBeluga2Provider
|
||||
from .openaigpt35turbo import OpenAIGPT35TurboProvider
|
||||
from .googleflant5xxl import GoogleFlant5XXLProvider
|
||||
from .openaigpt4 import OpenAIGPT4Provider
|
||||
from .gpt2 import GPT2Provider
|
||||
from .openassistantsft1pythia12b import HuggingFaceOpenAssistantSFT1PythiaProvider
|
||||
from .robertasquad2 import RobertaSquad2Provider
|
||||
|
||||
PROVIDERS = {
|
||||
BlenderBotProvider,
|
||||
CatGPTProvider,
|
||||
DialoGPTProvider
|
||||
DialoGPTProvider,
|
||||
OpenAIGPT35TurboProvider,
|
||||
OpenAIGPT4Provider,
|
||||
GoogleFlant5XXLProvider,
|
||||
GPT2Provider,
|
||||
# StableBeluga2Provider,
|
||||
# HuggingFaceOpenAssistantSFT1PythiaProvider,
|
||||
# RobertaSquad2Provider
|
||||
}
|
|
@ -14,34 +14,38 @@ class BaseProvider:
|
|||
data: Dict[str, str] = {}
|
||||
has_auth: bool = False
|
||||
require_authentification: bool = False
|
||||
base_url = "https://bavarder.codeberg.page/providers/"
|
||||
|
||||
def __init__(self, app, window, providers):
|
||||
def __init__(self, app, window):
|
||||
self.slug = self.slugify(self.name)
|
||||
self.copyright = f"© 2023 {self.developer_name}"
|
||||
self.url = f"https://bavarder.codeberg.page/providers/{self.slug}"
|
||||
self.url = f"{self.base_url}{self.slug}"
|
||||
|
||||
self.app = app
|
||||
self.window = window
|
||||
|
||||
self.providers = providers
|
||||
self.data
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
try:
|
||||
self.providers[self.slug]
|
||||
return self.app.data["providers"][self.slug]["data"]
|
||||
except KeyError:
|
||||
self.providers[self.slug] = {
|
||||
self.app.data["providers"][self.slug] = {
|
||||
"enabled": False,
|
||||
"data": {
|
||||
|
||||
}
|
||||
}
|
||||
finally:
|
||||
self.data = self.providers[self.slug]["data"]
|
||||
return self.app.data["providers"][self.slug]["data"]
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.providers[self.slug]["enabled"]
|
||||
return self.app.data["providers"][self.slug]["enabled"]
|
||||
|
||||
def set_enabled(self, status):
|
||||
self.providers[self.slug]["enabled"] = status
|
||||
self.app.data["providers"][self.slug]["enabled"] = status
|
||||
|
||||
def ask(self, prompt, chat):
|
||||
raise NotImplementedError()
|
||||
|
@ -64,3 +68,16 @@ class BaseProvider:
|
|||
prompt = [(prompt[i : i + n]) for i in range(0, len(prompt), n)]
|
||||
return prompt
|
||||
|
||||
def open_documentation(self, *args, **kwargs):
|
||||
GLib.spawn_command_line_async(
|
||||
f"xdg-open {self.url}"
|
||||
)
|
||||
|
||||
def how_to_get_a_token(self):
|
||||
about_button = Gtk.Button()
|
||||
about_button.set_icon_name("dialog-information-symbolic")
|
||||
about_button.set_tooltip_text(_("How to get a token"))
|
||||
about_button.add_css_class("flat")
|
||||
about_button.set_valign(Gtk.Align.CENTER)
|
||||
about_button.connect("clicked", self.open_documentation)
|
||||
return about_button
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
from .hfbasechat import BaseHFChatProvider
|
||||
|
||||
class GoogleFlant5XXLProvider(BaseHFChatProvider):
|
||||
name = "Google Flan T5 XXL"
|
||||
description = "A better Text-To-Text Transfer Transformer (T5) model"
|
||||
provider = "google/flan-t5-xxl"
|
||||
chat_mode = False
|
|
@ -0,0 +1,7 @@
|
|||
from .hfbasechat import BaseHFChatProvider
|
||||
|
||||
class GPT2Provider(BaseHFChatProvider):
|
||||
name = "GPT 2"
|
||||
description = "GPT-2 is a transformers model pretrained on a very large corpus of English data in a self-supervised fashion"
|
||||
provider = "gpt2"
|
||||
chat_mode = False
|
|
@ -1,10 +1,13 @@
|
|||
from .base import BaseProvider
|
||||
|
||||
import json
|
||||
import requests
|
||||
|
||||
from gi.repository import Gtk, Adw, GLib
|
||||
|
||||
|
||||
class BaseHFChatProvider(BaseProvider):
|
||||
provider = None
|
||||
chat_mode = True
|
||||
|
||||
def ask(self, prompt, chat, **kwargs):
|
||||
chat = chat["content"]
|
||||
|
@ -12,18 +15,26 @@ class BaseHFChatProvider(BaseProvider):
|
|||
API_URL = f"https://api-inference.huggingface.co/models/{self.provider}"
|
||||
|
||||
def query(payload):
|
||||
response = requests.post(API_URL, json=payload)
|
||||
if self.data.get('api_key'):
|
||||
headers = {"Authorization": f"Bearer {self.data['api_key']}"}
|
||||
response = requests.post(API_URL, json=payload, headers=headers)
|
||||
else:
|
||||
response = requests.post(API_URL, json=payload)
|
||||
|
||||
return response.json()
|
||||
|
||||
output = query({
|
||||
"inputs": {
|
||||
"past_user_inputs": [i['content'] for i in chat if i['role'] == self.app.user_name],
|
||||
"generated_responses": [i['content'] for i in chat if i['role'] == self.app.bot_name],
|
||||
"text": prompt
|
||||
},
|
||||
})
|
||||
|
||||
print(output)
|
||||
if self.chat_mode:
|
||||
output = query({
|
||||
"inputs": {
|
||||
"past_user_inputs": [i['content'] for i in chat if i['role'] == self.app.user_name],
|
||||
"generated_responses": [i['content'] for i in chat if i['role'] == self.app.bot_name],
|
||||
"text": prompt
|
||||
},
|
||||
})
|
||||
else:
|
||||
output = query({
|
||||
"inputs": self.make_prompt(prompt, chat),
|
||||
})
|
||||
|
||||
if 'generated_text' in output:
|
||||
return output['generated_text']
|
||||
|
@ -31,3 +42,28 @@ class BaseHFChatProvider(BaseProvider):
|
|||
match output['error']:
|
||||
case "Rate limit reached. Please log in or use your apiToken":
|
||||
return _("You've reached the rate limit! Please add a token to the preferences. You can get the token by following this [guide](https://bavarder.codeberg.page/help/huggingface/)")
|
||||
elif isinstance(output, list):
|
||||
if 'generated_text' in output[0]:
|
||||
return output[0]['generated_text']
|
||||
else:
|
||||
return _("Sorry, I don't know what to say!")
|
||||
|
||||
def get_settings_rows(self):
|
||||
self.rows = []
|
||||
|
||||
self.api_row = Adw.PasswordEntryRow()
|
||||
self.api_row.connect("apply", self.on_apply)
|
||||
self.api_row.props.text = self.data.get('api_key') or ""
|
||||
self.api_row.props.title = _("API Key")
|
||||
self.api_row.set_show_apply_button(True)
|
||||
self.api_row.add_suffix(self.how_to_get_a_token())
|
||||
self.rows.append(self.api_row)
|
||||
|
||||
return self.rows
|
||||
|
||||
def on_apply(self, widget):
|
||||
api_key = self.api_row.get_text()
|
||||
self.data["api_key"] = api_key
|
||||
|
||||
def make_prompt(self, prompt, chat):
|
||||
return prompt
|
|
@ -6,8 +6,17 @@ providers_sources = [
|
|||
'blenderbot.py',
|
||||
'catgpt.py',
|
||||
'dialogpt.py',
|
||||
'googleflant5xxl.py',
|
||||
'gpt2.py',
|
||||
'hfbasechat.py',
|
||||
'openai.py',
|
||||
'openaigpt35turbo.py',
|
||||
'openaigpt4.py',
|
||||
'openassistantsft1pythia12b.py',
|
||||
'petals.py',
|
||||
'provider_item.py',
|
||||
'stablebeluga2.py',
|
||||
'robertasquad2.py',
|
||||
]
|
||||
|
||||
PY_INSTALLDIR.install_sources(providers_sources, subdir: providers_dir)
|
|
@ -0,0 +1,91 @@
|
|||
from .base import BaseProvider
|
||||
|
||||
import openai
|
||||
import socket
|
||||
|
||||
from gi.repository import Gtk, Adw, GLib
|
||||
|
||||
|
||||
class BaseOpenAIProvider(BaseProvider):
|
||||
model = None
|
||||
api_key_title = "API Key"
|
||||
chat = openai.ChatCompletion
|
||||
|
||||
def __init__(self, app, window):
|
||||
super().__init__(app, window)
|
||||
|
||||
if self.data.get("api_key"):
|
||||
openai.api_key = self.data["api_key"]
|
||||
if self.data.get("api_base"):
|
||||
openai.api_base = self.data["api_base"]
|
||||
|
||||
def ask(self, prompt, chat):
|
||||
chat = chat["content"]
|
||||
|
||||
if self.data.get("api_key"):
|
||||
openai.api_key = self.data["api_key"]
|
||||
if self.data.get("api_base"):
|
||||
openai.api_base = self.data["api_base"]
|
||||
|
||||
if self.model:
|
||||
prompt = self.chunk(prompt)
|
||||
try:
|
||||
response = self.chat.create(
|
||||
model=self.model,
|
||||
messages=chat,
|
||||
).choices[0].message.content
|
||||
except openai.error.AuthenticationError:
|
||||
return _("Your API key is invalid, please check your preferences.")
|
||||
except openai.error.InvalidRequestError:
|
||||
return _("You don't have access to this model, please check your plan and billing details.")
|
||||
except openai.error.RateLimitError:
|
||||
return _("You exceeded your current quota, please check your plan and billing details.")
|
||||
except openai.error.APIError:
|
||||
return _("I'm having trouble connecting to the API, please check your internet connection.")
|
||||
except socket.gaierror:
|
||||
return _("I'm having trouble connecting to the API, please check your internet connection.")
|
||||
else:
|
||||
return response
|
||||
else:
|
||||
return _("No model selected, you can choose one in preferences")
|
||||
|
||||
|
||||
def get_settings_rows(self):
|
||||
self.rows = []
|
||||
|
||||
|
||||
self.api_row = Adw.PasswordEntryRow()
|
||||
self.api_row.connect("apply", self.on_apply)
|
||||
self.api_row.props.text = openai.api_key or ""
|
||||
self.api_row.props.title = self.api_key_title
|
||||
self.api_row.set_show_apply_button(True)
|
||||
self.api_row.add_suffix(self.how_to_get_a_token())
|
||||
self.rows.append(self.api_row)
|
||||
|
||||
self.api_url_row = Adw.EntryRow()
|
||||
self.api_url_row.connect("apply", self.on_apply)
|
||||
self.api_url_row.props.text = openai.api_base or ""
|
||||
self.api_url_row.props.title = "API Url"
|
||||
self.api_url_row.set_show_apply_button(True)
|
||||
self.api_url_row.add_suffix(self.how_to_get_base_url())
|
||||
self.rows.append(self.api_url_row)
|
||||
|
||||
return self.rows
|
||||
|
||||
def on_apply(self, widget):
|
||||
api_key = self.api_row.get_text()
|
||||
openai.api_key = api_key
|
||||
openai.api_base = self.api_url_row.get_text()
|
||||
|
||||
self.data["api_key"] = openai.api_key
|
||||
self.data["api_base"] = openai.api_base
|
||||
|
||||
|
||||
def how_to_get_base_url(self):
|
||||
about_button = Gtk.Button()
|
||||
about_button.set_icon_name("dialog-information-symbolic")
|
||||
about_button.set_tooltip_text("How to choose base url")
|
||||
about_button.add_css_class("flat")
|
||||
about_button.set_valign(Gtk.Align.CENTER)
|
||||
about_button.connect("clicked", self.open_documentation)
|
||||
return about_button
|
|
@ -0,0 +1,7 @@
|
|||
from .openai import BaseOpenAIProvider
|
||||
|
||||
|
||||
class OpenAIGPT35TurboProvider(BaseOpenAIProvider):
|
||||
name = "OpenAI GPT 3.5 Turbo"
|
||||
description = "Most capable GPT-3.5 model and optimized for chat."
|
||||
model = "gpt-3.5-turbo"
|
|
@ -0,0 +1,8 @@
|
|||
from .openai import BaseOpenAIProvider
|
||||
|
||||
|
||||
class OpenAIGPT4Provider(BaseOpenAIProvider):
|
||||
name = "OpenAI GPT 4"
|
||||
model = "gpt-4"
|
||||
description = "More capable than any GPT-3.5 model, able to do more complex tasks, and optimized for chat."
|
||||
api_key_title = "API Key (Require a plan with access to the GPT-4 model)"
|
|
@ -0,0 +1,15 @@
|
|||
from .hfbasechat import BaseHFChatProvider
|
||||
|
||||
class HuggingFaceOpenAssistantSFT1PythiaProvider(BaseHFChatProvider):
|
||||
name = "Open-Assistant SFT-1 12B"
|
||||
provider = "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"
|
||||
|
||||
def make_prompt(self, prompt, chat):
|
||||
p = ""
|
||||
for i in range(0, len(chat)):
|
||||
if chat[i]['role'] == self.app.bot_name:
|
||||
p += f"<|assistant|>{chat[i]['content']}<|endoftext|>"
|
||||
else:
|
||||
p += f"<|prompter|>{chat[i]['content']}<|endoftext|>"
|
||||
p += f"<|prompter|> {prompt}<|endoftext|>"
|
||||
return p
|
|
@ -0,0 +1,32 @@
|
|||
from .base import BaseProvider
|
||||
|
||||
import json
|
||||
import requests
|
||||
|
||||
class BasePetalsProvider(BaseProvider):
|
||||
provider = None
|
||||
|
||||
API_URL = "https://chat.petals.dev/"
|
||||
ENDPOINT = "/api/v1/generate"
|
||||
|
||||
model = None
|
||||
|
||||
def ask(self, prompt, chat, **kwargs):
|
||||
try:
|
||||
API_URL = self.data["api_url"]
|
||||
except KeyError:
|
||||
API_URL = self.API_URL
|
||||
|
||||
API_URL += self.ENDPOINT
|
||||
|
||||
chat = chat["content"]
|
||||
|
||||
|
||||
r = f"{API_URL}?model={self.model}&do_sample=1&temperature=0.75&top_p=0.9&max_length=1000&inputs={prompt}"
|
||||
|
||||
output = requests.post(r).json()
|
||||
|
||||
if output["ok"]:
|
||||
return output["outputs"]
|
||||
else:
|
||||
return _("I'm sorry, I don't know what to say!")
|
|
@ -24,13 +24,13 @@ class Provider(Adw.ExpanderRow):
|
|||
def setup(self):
|
||||
self.set_title(self.provider.name)
|
||||
self.set_subtitle(self.provider.description)
|
||||
self.enable_switch.set_active(self.provider.providers[self.provider.slug]["enabled"])
|
||||
self.enable_switch.set_active( self.app.data["providers"][self.provider.slug]["enabled"])
|
||||
|
||||
if self.provider.require_authentification:
|
||||
if self.provider.get_settings_rows():
|
||||
self.no_preferences_available.set_visible(False)
|
||||
|
||||
for row in self.provider.get_settings_rows():
|
||||
self.add(row)
|
||||
self.add_row(row)
|
||||
|
||||
# CALLBACKS
|
||||
@Gtk.Template.Callback()
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
from .hfbasechat import BaseHFChatProvider
|
||||
|
||||
class RobertaSquad2Provider(BaseHFChatProvider):
|
||||
name = "Roberta Squad2"
|
||||
provider = "deepset/roberta-base-squad2"
|
||||
|
||||
def make_prompt(self, prompt, chat):
|
||||
context = ""
|
||||
for message in chat:
|
||||
if chat['role'] == self.app.user_name:
|
||||
context += f" {message['content']}"
|
||||
return {
|
||||
"question": prompt,
|
||||
"context": context
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
from .petals import BasePetalsProvider
|
||||
|
||||
class StableBeluga2Provider(BasePetalsProvider):
|
||||
name = "stable-beluga2"
|
||||
model = "stabilityai/StableBeluga2"
|
|
@ -22,7 +22,6 @@ import locale
|
|||
|
||||
from gi.repository import Gtk, Gio, Adw, GLib
|
||||
from babel.dates import format_date, format_datetime, format_time
|
||||
from babel import Locale
|
||||
|
||||
from bavarder.constants import app_id, build_type, rootdir
|
||||
from bavarder.widgets.thread_item import ThreadItem
|
||||
|
@ -170,7 +169,10 @@ class BavarderWindow(Adw.ApplicationWindow):
|
|||
self.split_view.set_collapsed(True)
|
||||
self.split_view.set_show_content(True)
|
||||
|
||||
self.title.set_title(self.chat["title"])
|
||||
try:
|
||||
self.title.set_title(self.chat["title"])
|
||||
except KeyError:
|
||||
self.title.set_title(_("New chat"))
|
||||
|
||||
if self.content:
|
||||
self.stack.set_visible_child(self.main)
|
||||
|
@ -318,15 +320,21 @@ class BavarderWindow(Adw.ApplicationWindow):
|
|||
if prompt:
|
||||
self.message_entry.get_buffer().set_text("")
|
||||
|
||||
try:
|
||||
self.add_user_item(prompt)
|
||||
except AttributeError: # no chat
|
||||
if not self.chat:
|
||||
print("NEW CHAT")
|
||||
self.on_new_chat_action()
|
||||
row = self.threads_list.get_row_at_index(0)
|
||||
|
||||
# now get the latest row
|
||||
row = self.threads_list.get_row_at_index(len(self.app.data["chats"]) - 1)
|
||||
|
||||
|
||||
self.threads_list.select_row(row)
|
||||
self.threads_row_activated_cb()
|
||||
self.add_user_item(prompt)
|
||||
print(self.chat)
|
||||
|
||||
|
||||
self.add_user_item(prompt)
|
||||
|
||||
|
||||
def thread_run():
|
||||
self.toast = Adw.Toast()
|
||||
|
@ -342,6 +350,9 @@ class BavarderWindow(Adw.ApplicationWindow):
|
|||
self.t.join()
|
||||
self.toast.dismiss()
|
||||
|
||||
if not response:
|
||||
response = _("Sorry, I don't know what to say.")
|
||||
|
||||
self.add_assistant_item(response)
|
||||
|
||||
self.t = KillableThread(target=thread_run)
|
||||
|
|
Loading…
Reference in New Issue