feat: add more providers

This commit is contained in:
0xMRTT 2023-08-11 18:17:03 +02:00
parent fb08eaa9fe
commit b438eb1077
18 changed files with 390 additions and 34 deletions

View File

@ -46,8 +46,8 @@
"sources": [
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/e6/02/a2cff6306177ae6bc73bc0665065de51dfb3b9db7373e122e2735faf0d97/tqdm-4.65.0-py3-none-any.whl",
"sha256": "c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"
"url": "https://files.pythonhosted.org/packages/00/e5/f12a80907d0884e6dff9c16d0c0114d81b8cd07dc3ae54c5e962cc83037e/tqdm-4.66.1-py3-none-any.whl",
"sha256": "d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"
}
]
},
@ -106,6 +106,85 @@
"sha256": "b4246fb7677d3b98f501a39d43396d3cafdc8eadb045f4a31be01863f655c610"
}
]
},
{
"name": "python3-openai",
"buildsystem": "simple",
"build-commands": [
"pip3 install --verbose --exists-action=i --no-index --find-links=\"file://${PWD}\" --prefix=${FLATPAK_DEST} \"openai\" --no-build-isolation"
],
"sources": [
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/d6/12/6fc7c7dcc84e263940e87cbafca17c1ef28f39dae6c0b10f51e4ccc764ee/aiohttp-3.8.5.tar.gz",
"sha256": "b9552ec52cc147dbf1944ac7ac98af7602e51ea2dcd076ed194ca3c0d1c7d0bc"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl",
"sha256": "f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl",
"sha256": "7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/f0/eb/fcb708c7bf5056045e9e98f62b93bd7467eb718b0202e7698eb11d66416c/attrs-23.1.0-py3-none-any.whl",
"sha256": "1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/4c/dd/2234eab22353ffc7d94e8d13177aaa050113286e93e7b40eae01fbf7c3d9/certifi-2023.7.22-py3-none-any.whl",
"sha256": "92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/2a/53/cf0a48de1bdcf6ff6e1c9a023f5f523dfe303e4024f216feac64b6eb7f67/charset-normalizer-3.2.0.tar.gz",
"sha256": "3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/8c/1f/49c96ccc87127682ba900b092863ef7c20302a2144b3185412a08480ca22/frozenlist-1.4.0.tar.gz",
"sha256": "09163bdf0b2907454042edb19f887c6d33806adc71fbd54afc14908bfdc22251"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/fc/34/3030de6f1370931b9dbb4dad48f6ab1015ab1d32447850b9fc94e60097be/idna-3.4-py3-none-any.whl",
"sha256": "90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/4a/15/bd620f7a6eb9aa5112c4ef93e7031bcd071e0611763d8e17706ef8ba65e0/multidict-6.0.4.tar.gz",
"sha256": "3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/67/78/7588a047e458cb8075a4089d721d7af5e143ff85a2388d4a28c530be0494/openai-0.27.8-py3-none-any.whl",
"sha256": "e0a7c2f7da26bdbe5354b03c6d4b82a2f34bd4458c7a17ae1a7092c3e397e03c"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/70/8e/0e2d847013cb52cd35b38c009bb167a1a26b2ce6cd6965bf26b47bc0bf44/requests-2.31.0-py3-none-any.whl",
"sha256": "58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/00/e5/f12a80907d0884e6dff9c16d0c0114d81b8cd07dc3ae54c5e962cc83037e/tqdm-4.66.1-py3-none-any.whl",
"sha256": "d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/9b/81/62fd61001fa4b9d0df6e31d47ff49cfa9de4af03adecf339c7bc30656b37/urllib3-2.0.4-py3-none-any.whl",
"sha256": "de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4"
},
{
"type": "file",
"url": "https://files.pythonhosted.org/packages/5f/3f/04b3c5e57844fb9c034b09c5cb6d2b43de5d64a093c30529fd233e16cf09/yarl-1.9.2.tar.gz",
"sha256": "04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"
}
]
}
]
}

View File

@ -7,4 +7,5 @@ tqdm
charset-normalizer
idna
urllib3
Babel
Babel
openai

View File

@ -129,6 +129,8 @@ class BavarderApplication(Adw.Application):
def save(self):
with open(self.data_path, "w", encoding="utf-8") as f:
for name, d in self.data["providers"].items():
print(d)
self.data = json.dump(self.data, f)
self.settings.set_boolean("local-mode", self.local_mode)
self.settings.set_string("current-provider", self.current_provider)
@ -173,7 +175,7 @@ class BavarderApplication(Adw.Application):
self.providers = {}
for provider in PROVIDERS:
p = provider(self, self.win, self.data["providers"])
p = provider(self, self.win)
self.providers[p.slug] = p

View File

@ -1,9 +1,23 @@
from .blenderbot import BlenderBotProvider
from .catgpt import CatGPTProvider
from .dialogpt import DialoGPTProvider
from .stablebeluga2 import StableBeluga2Provider
from .openaigpt35turbo import OpenAIGPT35TurboProvider
from .googleflant5xxl import GoogleFlant5XXLProvider
from .openaigpt4 import OpenAIGPT4Provider
from .gpt2 import GPT2Provider
from .openassistantsft1pythia12b import HuggingFaceOpenAssistantSFT1PythiaProvider
from .robertasquad2 import RobertaSquad2Provider
# Registry of provider classes instantiated by the application at startup.
# Fix: the set previously listed DialoGPTProvider twice (the second copy
# missing its comma), which is a syntax error; keep a single entry.
PROVIDERS = {
    BlenderBotProvider,
    CatGPTProvider,
    DialoGPTProvider,
    OpenAIGPT35TurboProvider,
    OpenAIGPT4Provider,
    GoogleFlant5XXLProvider,
    GPT2Provider,
    # Disabled pending stabilisation of their backends:
    # StableBeluga2Provider,
    # HuggingFaceOpenAssistantSFT1PythiaProvider,
    # RobertaSquad2Provider
}

View File

@ -14,34 +14,38 @@ class BaseProvider:
data: Dict[str, str] = {}
has_auth: bool = False
require_authentification: bool = False
base_url = "https://bavarder.codeberg.page/providers/"
def __init__(self, app, window, providers):
def __init__(self, app, window):
self.slug = self.slugify(self.name)
self.copyright = f"© 2023 {self.developer_name}"
self.url = f"https://bavarder.codeberg.page/providers/{self.slug}"
self.url = f"{self.base_url}{self.slug}"
self.app = app
self.window = window
self.providers = providers
self.data
@property
def data(self):
try:
self.providers[self.slug]
return self.app.data["providers"][self.slug]["data"]
except KeyError:
self.providers[self.slug] = {
self.app.data["providers"][self.slug] = {
"enabled": False,
"data": {
}
}
finally:
self.data = self.providers[self.slug]["data"]
return self.app.data["providers"][self.slug]["data"]
@property
def enabled(self):
return self.providers[self.slug]["enabled"]
return self.app.data["providers"][self.slug]["enabled"]
def set_enabled(self, status):
self.providers[self.slug]["enabled"] = status
self.app.data["providers"][self.slug]["enabled"] = status
def ask(self, prompt, chat):
raise NotImplementedError()
@ -64,3 +68,16 @@ class BaseProvider:
prompt = [(prompt[i : i + n]) for i in range(0, len(prompt), n)]
return prompt
def open_documentation(self, *args, **kwargs):
GLib.spawn_command_line_async(
f"xdg-open {self.url}"
)
def how_to_get_a_token(self):
about_button = Gtk.Button()
about_button.set_icon_name("dialog-information-symbolic")
about_button.set_tooltip_text(_("How to get a token"))
about_button.add_css_class("flat")
about_button.set_valign(Gtk.Align.CENTER)
about_button.connect("clicked", self.open_documentation)
return about_button

View File

@ -0,0 +1,7 @@
from .hfbasechat import BaseHFChatProvider
class GoogleFlant5XXLProvider(BaseHFChatProvider):
    """Google Flan-T5 XXL served through the Hugging Face Inference API."""

    name = "Google Flan T5 XXL"
    description = "A better Text-To-Text Transfer Transformer (T5) model"
    # Hugging Face model id used by the base class to build the API URL.
    provider = "google/flan-t5-xxl"
    # Flan-T5 is text-to-text, not conversational: the base class sends a
    # plain prompt instead of a past_user_inputs/generated_responses payload.
    chat_mode = False

7
src/providers/gpt2.py Normal file
View File

@ -0,0 +1,7 @@
from .hfbasechat import BaseHFChatProvider
class GPT2Provider(BaseHFChatProvider):
    """GPT-2 text generation served through the Hugging Face Inference API."""

    name = "GPT 2"
    description = "GPT-2 is a transformers model pretrained on a very large corpus of English data in a self-supervised fashion"
    # Hugging Face model id used by the base class to build the API URL.
    provider = "gpt2"
    # GPT-2 is a plain language model: send the raw prompt, not a chat payload.
    chat_mode = False

View File

@ -1,10 +1,13 @@
from .base import BaseProvider
import json
import requests
from gi.repository import Gtk, Adw, GLib
class BaseHFChatProvider(BaseProvider):
provider = None
chat_mode = True
def ask(self, prompt, chat, **kwargs):
chat = chat["content"]
@ -12,18 +15,26 @@ class BaseHFChatProvider(BaseProvider):
API_URL = f"https://api-inference.huggingface.co/models/{self.provider}"
def query(payload):
response = requests.post(API_URL, json=payload)
if self.data.get('api_key'):
headers = {"Authorization": f"Bearer {self.data['api_key']}"}
response = requests.post(API_URL, json=payload, headers=headers)
else:
response = requests.post(API_URL, json=payload)
return response.json()
output = query({
"inputs": {
"past_user_inputs": [i['content'] for i in chat if i['role'] == self.app.user_name],
"generated_responses": [i['content'] for i in chat if i['role'] == self.app.bot_name],
"text": prompt
},
})
print(output)
if self.chat_mode:
output = query({
"inputs": {
"past_user_inputs": [i['content'] for i in chat if i['role'] == self.app.user_name],
"generated_responses": [i['content'] for i in chat if i['role'] == self.app.bot_name],
"text": prompt
},
})
else:
output = query({
"inputs": self.make_prompt(prompt, chat),
})
if 'generated_text' in output:
return output['generated_text']
@ -31,3 +42,28 @@ class BaseHFChatProvider(BaseProvider):
match output['error']:
case "Rate limit reached. Please log in or use your apiToken":
return _("You've reached the rate limit! Please add a token to the preferences. You can get the token by following this [guide](https://bavarder.codeberg.page/help/huggingface/)")
elif isinstance(output, list):
if 'generated_text' in output[0]:
return output[0]['generated_text']
else:
return _("Sorry, I don't know what to say!")
def get_settings_rows(self):
self.rows = []
self.api_row = Adw.PasswordEntryRow()
self.api_row.connect("apply", self.on_apply)
self.api_row.props.text = self.data.get('api_key') or ""
self.api_row.props.title = _("API Key")
self.api_row.set_show_apply_button(True)
self.api_row.add_suffix(self.how_to_get_a_token())
self.rows.append(self.api_row)
return self.rows
def on_apply(self, widget):
api_key = self.api_row.get_text()
self.data["api_key"] = api_key
def make_prompt(self, prompt, chat):
return prompt

View File

@ -6,8 +6,17 @@ providers_sources = [
'blenderbot.py',
'catgpt.py',
'dialogpt.py',
'googleflant5xxl.py',
'gpt2.py',
'hfbasechat.py',
'openai.py',
'openaigpt35turbo.py',
'openaigpt4.py',
'openassistantsft1pythia12b.py',
'petals.py',
'provider_item.py',
'stablebeluga2.py',
'robertasquad2.py',
]
PY_INSTALLDIR.install_sources(providers_sources, subdir: providers_dir)

91
src/providers/openai.py Normal file
View File

@ -0,0 +1,91 @@
from .base import BaseProvider
import openai
import socket
from gi.repository import Gtk, Adw, GLib
class BaseOpenAIProvider(BaseProvider):
    """Base class for providers that talk to an OpenAI-compatible chat API.

    Subclasses set ``model`` (and optionally ``api_key_title``). The API key
    and base URL are persisted in the provider's ``data`` mapping and mirrored
    into the module-level ``openai`` globals, which this class reads back when
    building its preference rows.
    """

    # Model identifier (e.g. "gpt-3.5-turbo"); None means "not configured".
    model = None
    # Title shown on the API-key preference row; subclasses may override.
    api_key_title = "API Key"
    # Completion endpoint wrapper, kept as a class attribute so a subclass
    # could substitute a different completion class.
    chat = openai.ChatCompletion

    def __init__(self, app, window):
        super().__init__(app, window)
        # Restore credentials persisted from previous sessions, if any.
        if self.data.get("api_key"):
            openai.api_key = self.data["api_key"]
        if self.data.get("api_base"):
            openai.api_base = self.data["api_base"]

    def ask(self, prompt, chat):
        """Send the chat history to the API and return the assistant's reply,
        or a translated human-readable error message on failure."""
        chat = chat["content"]
        # Re-apply stored credentials in case they changed after __init__.
        if self.data.get("api_key"):
            openai.api_key = self.data["api_key"]
        if self.data.get("api_base"):
            openai.api_base = self.data["api_base"]
        if self.model:
            # NOTE(review): prompt is chunked here but only `chat` is sent in
            # the request below — the chunked prompt looks unused; confirm.
            prompt = self.chunk(prompt)
            try:
                response = self.chat.create(
                    model=self.model,
                    messages=chat,
                ).choices[0].message.content
            except openai.error.AuthenticationError:
                return _("Your API key is invalid, please check your preferences.")
            except openai.error.InvalidRequestError:
                return _("You don't have access to this model, please check your plan and billing details.")
            except openai.error.RateLimitError:
                return _("You exceeded your current quota, please check your plan and billing details.")
            except openai.error.APIError:
                return _("I'm having trouble connecting to the API, please check your internet connection.")
            except socket.gaierror:
                # DNS resolution failed — treated the same as an API outage.
                return _("I'm having trouble connecting to the API, please check your internet connection.")
            else:
                return response
        else:
            return _("No model selected, you can choose one in preferences")

    def get_settings_rows(self):
        """Build the preference rows (API key and API base URL) for this provider."""
        self.rows = []

        self.api_row = Adw.PasswordEntryRow()
        self.api_row.connect("apply", self.on_apply)
        self.api_row.props.text = openai.api_key or ""
        self.api_row.props.title = self.api_key_title
        self.api_row.set_show_apply_button(True)
        self.api_row.add_suffix(self.how_to_get_a_token())
        self.rows.append(self.api_row)

        self.api_url_row = Adw.EntryRow()
        self.api_url_row.connect("apply", self.on_apply)
        self.api_url_row.props.text = openai.api_base or ""
        self.api_url_row.props.title = "API Url"
        self.api_url_row.set_show_apply_button(True)
        self.api_url_row.add_suffix(self.how_to_get_base_url())
        self.rows.append(self.api_url_row)

        return self.rows

    def on_apply(self, widget):
        """Persist both rows' values; connected to each row's "apply" signal,
        so applying either row saves key and base URL together."""
        api_key = self.api_row.get_text()
        openai.api_key = api_key
        openai.api_base = self.api_url_row.get_text()
        self.data["api_key"] = openai.api_key
        self.data["api_base"] = openai.api_base

    def how_to_get_base_url(self):
        """Return a flat info button that opens this provider's documentation."""
        about_button = Gtk.Button()
        about_button.set_icon_name("dialog-information-symbolic")
        about_button.set_tooltip_text("How to choose base url")
        about_button.add_css_class("flat")
        about_button.set_valign(Gtk.Align.CENTER)
        about_button.connect("clicked", self.open_documentation)
        return about_button

View File

@ -0,0 +1,7 @@
from .openai import BaseOpenAIProvider
class OpenAIGPT35Turbo Provider would be wrong — see below.

View File

@ -0,0 +1,8 @@
from .openai import BaseOpenAIProvider
class OpenAIGPT4Provider(BaseOpenAIProvider):
    """OpenAI GPT-4 via the ChatCompletion machinery in BaseOpenAIProvider."""

    name = "OpenAI GPT 4"
    # Model id sent with each ChatCompletion request.
    model = "gpt-4"
    description = "More capable than any GPT-3.5 model, able to do more complex tasks, and optimized for chat."
    # Shown on the API-key row to warn that GPT-4 access requires a paid plan.
    api_key_title = "API Key (Require a plan with access to the GPT-4 model)"

View File

@ -0,0 +1,15 @@
from .hfbasechat import BaseHFChatProvider
class HuggingFaceOpenAssistantSFT1PythiaProvider(BaseHFChatProvider):
    """Open-Assistant Pythia-12B model served via the Hugging Face Inference API."""

    name = "Open-Assistant SFT-1 12B"
    # Hugging Face model id used by the base class to build the API URL.
    provider = "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"

    def make_prompt(self, prompt, chat):
        """Serialize the chat history plus the new prompt into Open-Assistant's
        special-token format (<|assistant|>/<|prompter|> … <|endoftext|>)."""
        parts = []
        for message in chat:
            # Messages from the bot get the assistant tag; everything else is
            # treated as coming from the prompter.
            if message['role'] == self.app.bot_name:
                tag = "<|assistant|>"
            else:
                tag = "<|prompter|>"
            parts.append(f"{tag}{message['content']}<|endoftext|>")
        parts.append(f"<|prompter|> {prompt}<|endoftext|>")
        return "".join(parts)

32
src/providers/petals.py Normal file
View File

@ -0,0 +1,32 @@
from .base import BaseProvider
import json
import requests
class BasePetalsProvider(BaseProvider):
    """Base class for providers served by a Petals swarm chat endpoint.

    Subclasses set ``model``; the endpoint host can be overridden per provider
    via the persisted ``data["api_url"]`` setting.
    """

    provider = None
    # Default public Petals chat host; overridden by data["api_url"] if set.
    API_URL = "https://chat.petals.dev/"
    ENDPOINT = "/api/v1/generate"
    # Model identifier passed as the `model` request parameter.
    model = None

    def ask(self, prompt, chat, **kwargs):
        """POST the prompt to the generate endpoint and return the output text,
        or a translated fallback message when the server reports failure."""
        try:
            API_URL = self.data["api_url"]
        except KeyError:
            API_URL = self.API_URL
        API_URL += self.ENDPOINT
        chat = chat["content"]
        # Fix: the prompt was previously interpolated raw into the query
        # string, which broke on spaces and characters like '&' or '#'.
        # Passing params= lets requests percent-encode everything properly.
        params = {
            "model": self.model,
            "do_sample": 1,
            "temperature": 0.75,
            "top_p": 0.9,
            "max_length": 1000,
            "inputs": prompt,
        }
        output = requests.post(API_URL, params=params).json()
        if output["ok"]:
            return output["outputs"]
        else:
            return _("I'm sorry, I don't know what to say!")

View File

@ -24,13 +24,13 @@ class Provider(Adw.ExpanderRow):
def setup(self):
self.set_title(self.provider.name)
self.set_subtitle(self.provider.description)
self.enable_switch.set_active(self.provider.providers[self.provider.slug]["enabled"])
self.enable_switch.set_active( self.app.data["providers"][self.provider.slug]["enabled"])
if self.provider.require_authentification:
if self.provider.get_settings_rows():
self.no_preferences_available.set_visible(False)
for row in self.provider.get_settings_rows():
self.add(row)
self.add_row(row)
# CALLBACKS
@Gtk.Template.Callback()

View File

@ -0,0 +1,15 @@
from .hfbasechat import BaseHFChatProvider
class RobertaSquad2Provider(BaseHFChatProvider):
    """Extractive question answering with deepset's RoBERTa SQuAD2 model
    served via the Hugging Face Inference API."""

    name = "Roberta Squad2"
    # Hugging Face model id used by the base class to build the API URL.
    provider = "deepset/roberta-base-squad2"

    def make_prompt(self, prompt, chat):
        """Build the QA payload: the new prompt is the question, and the
        context is assembled from the user's previous messages."""
        context = ""
        for message in chat:
            # Fix: the loop previously checked chat['role'] — indexing the
            # list itself (a TypeError) — instead of the current message.
            if message['role'] == self.app.user_name:
                context += f" {message['content']}"
        return {
            "question": prompt,
            "context": context
        }

View File

@ -0,0 +1,5 @@
from .petals import BasePetalsProvider
class StableBeluga2Provider(BasePetalsProvider):
    """Stability AI's StableBeluga2 model served through the Petals swarm."""

    name = "stable-beluga2"
    # Model id passed to the Petals generate endpoint by the base class.
    model = "stabilityai/StableBeluga2"

View File

@ -22,7 +22,6 @@ import locale
from gi.repository import Gtk, Gio, Adw, GLib
from babel.dates import format_date, format_datetime, format_time
from babel import Locale
from bavarder.constants import app_id, build_type, rootdir
from bavarder.widgets.thread_item import ThreadItem
@ -170,7 +169,10 @@ class BavarderWindow(Adw.ApplicationWindow):
self.split_view.set_collapsed(True)
self.split_view.set_show_content(True)
self.title.set_title(self.chat["title"])
try:
self.title.set_title(self.chat["title"])
except KeyError:
self.title.set_title(_("New chat"))
if self.content:
self.stack.set_visible_child(self.main)
@ -318,15 +320,21 @@ class BavarderWindow(Adw.ApplicationWindow):
if prompt:
self.message_entry.get_buffer().set_text("")
try:
self.add_user_item(prompt)
except AttributeError: # no chat
if not self.chat:
print("NEW CHAT")
self.on_new_chat_action()
row = self.threads_list.get_row_at_index(0)
# now get the latest row
row = self.threads_list.get_row_at_index(len(self.app.data["chats"]) - 1)
self.threads_list.select_row(row)
self.threads_row_activated_cb()
self.add_user_item(prompt)
print(self.chat)
self.add_user_item(prompt)
def thread_run():
self.toast = Adw.Toast()
@ -342,6 +350,9 @@ class BavarderWindow(Adw.ApplicationWindow):
self.t.join()
self.toast.dismiss()
if not response:
response = _("Sorry, I don't know what to say.")
self.add_assistant_item(response)
self.t = KillableThread(target=thread_run)