diff --git a/data/icons/hicolor/scalable/actions/object-select-symbolic.svg b/data/icons/hicolor/scalable/actions/object-select-symbolic.svg
new file mode 100644
index 0000000..2fd2f94
--- /dev/null
+++ b/data/icons/hicolor/scalable/actions/object-select-symbolic.svg
@@ -0,0 +1,3 @@
+
diff --git a/po/Bavarder.pot b/po/Bavarder.pot
index 720dfb9..736f27f 100644
--- a/po/Bavarder.pot
+++ b/po/Bavarder.pot
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2023-08-21 12:44+0200\n"
+"POT-Creation-Date: 2023-08-22 17:12+0200\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME \n"
"Language-Team: LANGUAGE \n"
@@ -37,11 +37,11 @@ msgid ""
"help/huggingface/)"
msgstr ""
-#: src/providers/hfbasechat.py:49
+#: src/providers/hfbasechat.py:50
msgid "Sorry, I don't know what to say!"
msgstr ""
-#: src/providers/hfbasechat.py:57
+#: src/providers/hfbasechat.py:58
msgid "API Key"
msgstr ""
@@ -106,7 +106,7 @@ msgid "Cancel"
msgstr ""
#: src/views/preferences_window.py:78 src/views/window.py:218
-#: src/widgets/thread_item.blp:56 src/widgets/thread_item.py:108
+#: src/widgets/thread_item.blp:58 src/widgets/thread_item.py:108
msgid "Delete"
msgstr ""
@@ -398,7 +398,7 @@ msgstr ""
msgid "Edit Title"
msgstr ""
-#: src/widgets/thread_item.blp:61
+#: src/widgets/thread_item.blp:63
msgid "Star"
msgstr ""
@@ -418,16 +418,16 @@ msgstr ""
msgid "Thread Deleted"
msgstr ""
-#: src/main.py:261
+#: src/main.py:265
msgid ""
"Please download a model from Preferences by clicking on the Dot Menu at the "
"top!"
msgstr ""
-#: src/main.py:265
+#: src/main.py:269
msgid "Hello, I am Bavarder, a Chit-Chat AI"
msgstr ""
-#: src/main.py:288
+#: src/main.py:292
msgid "Please enable a provider from the Dot Menu"
msgstr ""
diff --git a/po/LINGUAS b/po/LINGUAS
index 1a9e7c9..83900f6 100644
--- a/po/LINGUAS
+++ b/po/LINGUAS
@@ -8,6 +8,7 @@ fa
fi
fr
gl
+he
hu
it
nl
@@ -21,4 +22,3 @@ tr
uk
zh_Hans
zh_Hant
-he
diff --git a/src/bavarder.gresource.xml b/src/bavarder.gresource.xml
index ea6b3db..6081226 100644
--- a/src/bavarder.gresource.xml
+++ b/src/bavarder.gresource.xml
@@ -26,6 +26,7 @@
../data/icons/hicolor/scalable/actions/cloud-filled-symbolic.svg
../data/icons/hicolor/scalable/actions/document-edit-symbolic.svg
../data/icons/hicolor/scalable/actions/go-bottom-symbolic.svg
+ ../data/icons/hicolor/scalable/actions/object-select-symbolic.svg
../data/icons/hicolor/scalable/actions/paper-plane-symbolic.svg
../data/icons/hicolor/scalable/actions/settings-symbolic.svg
../data/icons/hicolor/scalable/actions/terminal-symbolic.svg
diff --git a/src/main.py b/src/main.py
index 2176e33..70d0a34 100644
--- a/src/main.py
+++ b/src/main.py
@@ -84,7 +84,11 @@ class BavarderApplication(Adw.Application):
self.data = {
"chats": [],
- "providers": {},
+ "providers": {
+ "google-flan-t5-xxl": {"enabled": True, "data": {}},
+ "gpt-2": {"enabled": True, "data": {}},
+
+ },
"models": {}
}
@@ -216,8 +220,8 @@ class BavarderApplication(Adw.Application):
def on_preferences_action(self, widget, _):
"""Callback for the app.preferences action."""
- preferences = PreferencesWindow(self.win)
- preferences.present()
+ self.preferences_window = PreferencesWindow(self.win)
+ self.preferences_window.present()
def create_action(self, name, callback, shortcuts=None):
diff --git a/src/providers/__init__.py b/src/providers/__init__.py
index 7d86121..ed2137d 100644
--- a/src/providers/__init__.py
+++ b/src/providers/__init__.py
@@ -9,8 +9,10 @@ from .gpt2 import GPT2Provider
from .openassistantsft1pythia12b import HuggingFaceOpenAssistantSFT1PythiaProvider
from .robertasquad2 import RobertaSquad2Provider
from .local import LocalProvider
+from .aihorde import AIHordeProvider
PROVIDERS = {
+ AIHordeProvider,
BlenderBotProvider,
CatGPTProvider,
DialoGPTProvider,
diff --git a/src/providers/aihorde.py b/src/providers/aihorde.py
new file mode 100644
index 0000000..5d19f46
--- /dev/null
+++ b/src/providers/aihorde.py
@@ -0,0 +1,154 @@
+from .base import BaseProvider
+
+import json
+import requests
+import time
+
+from gi.repository import Adw, Gtk
+
+class AIHordeProvider(BaseProvider):
+ name = "AI Horde"
+
+ ASYNC_URL = "https://stablehorde.net/api/v2/generate/text/async"
+ STATUS_URL = "https://stablehorde.net/api/v2/generate/text/status/"
+ API_KEY = "0000000000"
+ model = "PygmalionAI/pygmalion-7b"
+
+
+ def ask(self, prompt, chat, **kwargs):
+ self.API_KEY = self.data.get("api_key", "0000000000")
+
+ chat = chat["content"]
+
+ self.headers = {
+ "Client-Agent": "bavarder:1:linux",
+ "apikey": self.API_KEY,
+ }
+
+ data = {
+ "prompt": prompt,
+ "models": [
+ self.model
+ ]
+ }
+
+ r = requests.post(self.ASYNC_URL, json=data, headers=self.headers)
+
+ if r.status_code == 202:
+ rid = r.json()["id"]
+ else:
+ print(r.json())
+ print(r.status_code)
+ return _("I'm sorry, I don't know what to say!")
+
+
+ # do the request every seconds and check if it's finished
+ while True:
+ r = self.check_status(rid)
+ if r:
+ return r
+ else:
+ time.sleep(1)
+ return _("I'm sorry, I don't know what to say!")
+
+ def check_status(self, rid):
+ r = requests.get(self.STATUS_URL + rid)
+ rj = r.json()
+
+ if r.status_code == 200:
+ print(rj)
+ if rj["done"]:
+ return r.json()["generations"][0]["text"]
+ return None
+
+ def get_settings_rows(self):
+ self.rows = []
+
+ self.api_row = Adw.PasswordEntryRow()
+ self.api_row.connect("apply", self.on_apply)
+ self.api_row.props.text = self.data.get('api_key') or self.API_KEY
+ self.api_row.props.title = _("API Key")
+ self.api_row.set_show_apply_button(True)
+ self.api_row.add_suffix(self.how_to_get_a_token())
+ self.rows.append(self.api_row)
+
+ r = requests.get("https://stablehorde.net/api/v2/status/models?type=text")
+
+ if r.status_code != 200:
+ print(r.json())
+ return self.rows
+ else:
+ rj = r.json()
+
+ models_row = Adw.ActionRow()
+ models_row.set_title(_("Models"))
+ models_row.set_subtitle(_("Select a model to use"))
+
+ go_to_sub_button = Gtk.Button.new_from_icon_name("go-next-symbolic")
+ go_to_sub_button.set_valign(Gtk.Align.CENTER)
+ go_to_sub_button.set_tooltip_text(_("Go to the models page"))
+ go_to_sub_button.add_css_class("flat")
+ go_to_sub_button.connect("clicked", self.open_subpage)
+
+ models_row.add_suffix(go_to_sub_button)
+
+ self.page = Adw.NavigationPage()
+
+ prefpage = Adw.PreferencesPage()
+
+ group = Adw.PreferencesGroup()
+
+ self.selected_row = Adw.ActionRow()
+ self.selected_row.set_title(_("Selected model"))
+ if self.model:
+ self.selected_row.set_subtitle(self.model)
+ else:
+ self.selected_row.set_subtitle(_("No model selected"))
+
+ group.add(self.selected_row)
+
+ for model in rj:
+ mr = Adw.ActionRow()
+ mr.props.title = model["name"]
+ mr.props.subtitle = f"Performance {model['performance']} - Jobs {model['jobs']} - Queued {model['queued']}"
+
+ apply_button = Gtk.Button.new_from_icon_name("object-select-symbolic")
+ apply_button.connect("clicked", self.on_apply_model, model["name"])
+ apply_button.set_valign(Gtk.Align.CENTER)
+ apply_button.set_tooltip_text(_("Select this model"))
+ apply_button.add_css_class("flat")
+
+ mr.add_suffix(apply_button)
+
+ group.add(mr)
+
+
+ toolbar = Adw.ToolbarView()
+ header = Adw.HeaderBar()
+ label = Gtk.Label()
+ label.set_label(_("Models"))
+ header.set_title_widget(label)
+ toolbar.add_top_bar(header)
+ prefpage.add(group)
+ toolbar.set_content(prefpage)
+ self.page.set_child(toolbar)
+
+
+
+ self.rows.append(models_row)
+
+ return self.rows
+
+ def open_subpage(self, widget):
+ self.app.preferences_window.push_subpage(self.page)
+
+ def on_apply(self, widget):
+ self.API_KEY = self.api_row.get_text()
+ self.data["api_key"] = self.API_KEY
+
+ def on_apply_model(self, widget, name):
+ self.model = name
+ if self.model:
+ self.selected_row.set_subtitle(self.model)
+ else:
+ self.selected_row.set_subtitle(_("No model selected"))
diff --git a/src/providers/hfbasechat.py b/src/providers/hfbasechat.py
index 38a1cb4..28bd049 100644
--- a/src/providers/hfbasechat.py
+++ b/src/providers/hfbasechat.py
@@ -46,6 +46,7 @@ class BaseHFChatProvider(BaseProvider):
if 'generated_text' in output[0]:
return output[0]['generated_text']
else:
+ print(output)
return _("Sorry, I don't know what to say!")
def get_settings_rows(self):
diff --git a/src/providers/meson.build b/src/providers/meson.build
index 0062b8d..e178e58 100644
--- a/src/providers/meson.build
+++ b/src/providers/meson.build
@@ -2,6 +2,7 @@ providers_dir = join_paths(MODULE_DIR, 'providers')
providers_sources = [
'__init__.py',
+ 'aihorde.py',
'base.py',
'blenderbot.py',
'catgpt.py',