diff --git a/flake.nix b/flake.nix index 645b1df..686f5e8 100644 --- a/flake.nix +++ b/flake.nix @@ -55,6 +55,7 @@ lxml openai pygobject3 + pillow requests ]; diff --git a/src/providers/__init__.py b/src/providers/__init__.py index ed2137d..2a9df78 100644 --- a/src/providers/__init__.py +++ b/src/providers/__init__.py @@ -10,6 +10,7 @@ from .openassistantsft1pythia12b import HuggingFaceOpenAssistantSFT1PythiaProvid from .robertasquad2 import RobertaSquad2Provider from .local import LocalProvider from .aihorde import AIHordeProvider +from .stablediffusion import StableDiffusionProvider PROVIDERS = { AIHordeProvider, @@ -20,7 +21,8 @@ PROVIDERS = { OpenAIGPT4Provider, GoogleFlant5XXLProvider, GPT2Provider, - LocalProvider + LocalProvider, + StableDiffusionProvider, # StableBeluga2Provider, # HuggingFaceOpenAssistantSFT1PythiaProvider, # RobertaSquad2Provider diff --git a/src/providers/basehfimage.py b/src/providers/basehfimage.py new file mode 100644 index 0000000..460a321 --- /dev/null +++ b/src/providers/basehfimage.py @@ -0,0 +1,64 @@ +from .baseimage import BaseImageProvider +import requests +import json +from gi.repository import Gtk, Adw, GLib +from PIL import Image, UnidentifiedImageError +import io + + +class BaseHFImageProvider(BaseImageProvider): + provider = None + + def ask(self, prompt, chat, **kwargs): + chat = chat["content"] + + API_URL = f"https://api-inference.huggingface.co/models/{self.provider}" + + def query(payload): + if self.data.get('api_key'): + headers = {"Authorization": f"Bearer {self.data['api_key']}"} + response = requests.post(API_URL, json=payload, headers=headers) + else: + response = requests.post(API_URL, json=payload) + + if response.status_code == 403: + return _("You've reached the rate limit! Please add a token to the preferences. You can get the token by following this [guide](https://bavarder.codeberg.page/help/huggingface/)") + elif response.status_code != 200: + return _("Sorry, I don't know what to say! (Error: {response.status_code})") + + return response.content + + prompt = self.make_prompt(prompt, chat) + output = query({ + "inputs": prompt, + "negative_prompts": "", + }) + + if output: + print(output) + try: + print("IMAGE") + return Image.open(io.BytesIO(output)) + except UnidentifiedImageError: + print("FAILED IMAGE") + return output + + def get_settings_rows(self): + self.rows = [] + + self.api_row = Adw.PasswordEntryRow() + self.api_row.connect("apply", self.on_apply) + self.api_row.props.text = self.data.get('api_key') or "" + self.api_row.props.title = _("API Key") + self.api_row.set_show_apply_button(True) + self.api_row.add_suffix(self.how_to_get_a_token()) + self.rows.append(self.api_row) + + return self.rows + + def on_apply(self, widget): + api_key = self.api_row.get_text() + self.data["api_key"] = api_key + + def make_prompt(self, prompt, chat): + return prompt \ No newline at end of file diff --git a/src/providers/baseimage.py b/src/providers/baseimage.py new file mode 100644 index 0000000..d7eefd9 --- /dev/null +++ b/src/providers/baseimage.py @@ -0,0 +1,10 @@ +from .base import BaseProvider, ProviderType + +import requests + +from gi.repository import Gtk, Adw, GLib + + +class BaseImageProvider(BaseProvider): + provider_type = ProviderType.IMAGE + \ No newline at end of file diff --git a/src/providers/meson.build b/src/providers/meson.build index e178e58..4de5d35 100644 --- a/src/providers/meson.build +++ b/src/providers/meson.build @@ -4,6 +4,8 @@ providers_sources = [ '__init__.py', 'aihorde.py', 'base.py', + 'basehfimage.py', + 'baseimage.py', 'blenderbot.py', 'catgpt.py', 'dialogpt.py', @@ -19,6 +21,7 @@ providers_sources = [ 'provider_item.py', 'stablebeluga2.py', 'robertasquad2.py', + 'stablediffusion.py', ] PY_INSTALLDIR.install_sources(providers_sources, subdir: providers_dir) \ No newline at end of file diff --git a/src/providers/stablediffusion.py b/src/providers/stablediffusion.py new file mode 100644 index 0000000..7e79655 --- /dev/null +++ b/src/providers/stablediffusion.py @@ -0,0 +1,5 @@ +from .basehfimage import BaseHFImageProvider + +class StableDiffusionProvider(BaseHFImageProvider): + name = "Stable Diffusion" + provider = "stabilityai/stable-diffusion-2-1" \ No newline at end of file diff --git a/src/views/window.py b/src/views/window.py index da50990..bb5a040 100644 --- a/src/views/window.py +++ b/src/views/window.py @@ -19,6 +19,8 @@ from datetime import datetime import locale +import io +import base64 from gi.repository import Gtk, Gio, Adw, GLib from babel.dates import format_date, format_datetime, format_time @@ -392,9 +394,17 @@ class BavarderWindow(Adw.ApplicationWindow): self.toast.dismiss() if not response: - response = _("Sorry, I don't know what to say.") + self.add_assistant_item(_("Sorry, I don't know what to say.")) + else: + if isinstance(response, str): + self.add_assistant_item(response) + else: + buffered = io.BytesIO() + response.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()) + + self.add_assistant_item(img_str.decode("utf-8")) - self.add_assistant_item(response) except AttributeError: self.toast.dismiss() self.add_assistant_item(_("Sorry, I don't know what to say.")) diff --git a/src/widgets/item.py b/src/widgets/item.py index 9a7598a..2fa652c 100644 --- a/src/widgets/item.py +++ b/src/widgets/item.py @@ -1,6 +1,10 @@ from gi.repository import Gtk, Adw, Gio, GLib, Pango, GtkSource, Gdk import re +import io +import base64 + +from PIL import Image, UnidentifiedImageError from bavarder.constants import app_id, rootdir from bavarder.widgets.code_block import CodeBlock @@ -72,63 +76,86 @@ class Item(Gtk.Box): self.content_text = self.item["content"] - self.convert_content_to_pango() - - result = "" - is_code = False - for line in self.content_markup: - if isinstance(line, str): - if "`" in line.strip(): - if is_code: - is_code = False - else: - is_code = True - continue - if is_code or not isinstance(line, str): - label = Gtk.Label() - label.set_use_markup(True) - label.set_wrap(True) - label.set_xalign(0) - label.set_wrap_mode(Pango.WrapMode.WORD) - label.set_markup(result) - label.set_justify(Gtk.Justification.LEFT) - label.set_valign(Gtk.Align.START) - label.set_hexpand(True) - label.set_halign(Gtk.Align.START) - self.content.append(label) - - if not isinstance(line, str): - result = "\n".join(line) - else: - result = line.strip() - - self.content.append(CodeBlock(result)) - result = "" - else: - result += f"{line}\n" - - else: - if not result.strip() == "`": - label = Gtk.Label() - label.set_use_markup(True) - label.set_wrap(True) - label.set_xalign(0) - label.set_wrap_mode(Pango.WrapMode.WORD) - label.set_markup(result) - label.set_justify(Gtk.Justification.LEFT) - label.set_valign(Gtk.Align.START) - label.set_hexpand(True) - label.set_halign(Gtk.Align.START) - self.content.append(label) - - t = self.item["role"].lower() - self.parent = parent self.settings = parent.settings self.app = self.parent.get_application() self.win = self.app.get_active_window() + try: + if not isinstance(self.content_text, Image.Image): + if isinstance(self.content_text, bytes): + image = Image.open(io.BytesIO(self.content_text)) + else: + image = Image.open(io.BytesIO(base64.b64decode(self.content_text))) + else: + image = self.content_text + except Exception: + self.convert_content_to_pango() + + result = "" + is_code = False + for line in self.content_markup: + if isinstance(line, str): + if "`" in line.strip(): + if is_code: + is_code = False + else: + is_code = True + continue + if is_code or not isinstance(line, str): + label = Gtk.Label() + label.set_use_markup(True) + label.set_wrap(True) + label.set_xalign(0) + label.set_wrap_mode(Pango.WrapMode.WORD) + label.set_markup(result) + label.set_justify(Gtk.Justification.LEFT) + label.set_valign(Gtk.Align.START) + label.set_hexpand(True) + label.set_halign(Gtk.Align.START) + self.content.append(label) + + if not isinstance(line, str): + result = "\n".join(line) + else: + result = line.strip() + + self.content.append(CodeBlock(result)) + result = "" + else: + result += f"{line}\n" + + else: + if not result.strip() == "`": + label = Gtk.Label() + label.set_use_markup(True) + label.set_wrap(True) + label.set_xalign(0) + label.set_wrap_mode(Pango.WrapMode.WORD) + label.set_markup(result) + label.set_justify(Gtk.Justification.LEFT) + label.set_valign(Gtk.Align.START) + label.set_hexpand(True) + label.set_halign(Gtk.Align.START) + self.content.append(label) + else: + picture = Gtk.Picture() + picture.set_halign(Gtk.Align.CENTER) + picture.set_can_shrink(True) + picture.set_content_fit(Gtk.ContentFit.FILL) + picture.set_visible(True) + picture.add_css_class("card") + picture.set_margin_start(12) + picture.set_margin_end(12) + #print(self.content.get_width(), self.content.get_height()) + picture.set_size_request(270, 270) + image.save("/tmp/image.png") + picture.set_file(Gio.File.new_for_path("/tmp/image.png")) + self.content.append(picture) + + t = self.item["role"].lower() + if t == self.app.user_name.lower() or t == "user": # User self.message_bubble.add_css_class("message-bubble-user") self.avatar.add_css_class("avatar-user")