diff --git a/flake.nix b/flake.nix
index 645b1df..686f5e8 100644
--- a/flake.nix
+++ b/flake.nix
@@ -55,6 +55,7 @@
lxml
openai
pygobject3
+ pillow
requests
];
diff --git a/src/providers/__init__.py b/src/providers/__init__.py
index ed2137d..2a9df78 100644
--- a/src/providers/__init__.py
+++ b/src/providers/__init__.py
@@ -10,6 +10,7 @@ from .openassistantsft1pythia12b import HuggingFaceOpenAssistantSFT1PythiaProvid
from .robertasquad2 import RobertaSquad2Provider
from .local import LocalProvider
from .aihorde import AIHordeProvider
+from .stablediffusion import StableDiffusionProvider
PROVIDERS = {
AIHordeProvider,
@@ -20,7 +21,8 @@ PROVIDERS = {
OpenAIGPT4Provider,
GoogleFlant5XXLProvider,
GPT2Provider,
- LocalProvider
+ LocalProvider,
+ StableDiffusionProvider,
# StableBeluga2Provider,
# HuggingFaceOpenAssistantSFT1PythiaProvider,
# RobertaSquad2Provider
diff --git a/src/providers/basehfimage.py b/src/providers/basehfimage.py
new file mode 100644
index 0000000..460a321
--- /dev/null
+++ b/src/providers/basehfimage.py
@@ -0,0 +1,64 @@
+from .baseimage import BaseImageProvider
+import requests
+import json
+from gi.repository import Gtk, Adw, GLib
+from PIL import Image, UnidentifiedImageError
+import io
+
+
+class BaseHFImageProvider(BaseImageProvider):
+ provider = None
+
+ def ask(self, prompt, chat, **kwargs):
+ chat = chat["content"]
+
+ API_URL = f"https://api-inference.huggingface.co/models/{self.provider}"
+
+ def query(payload):
+ if self.data.get('api_key'):
+ headers = {"Authorization": f"Bearer {self.data['api_key']}"}
+ response = requests.post(API_URL, json=payload, headers=headers)
+ else:
+ response = requests.post(API_URL, json=payload)
+
+ if response.status_code == 403:
+ return _("You've reached the rate limit! Please add a token to the preferences. You can get the token by following this [guide](https://bavarder.codeberg.page/help/huggingface/)")
+ elif response.status_code != 200:
+ return _("Sorry, I don't know what to say! (Error: {response.status_code})")
+
+ return response.content
+
+ prompt = self.make_prompt(prompt, chat)
+ output = query({
+ "inputs": prompt,
+ "negative_prompts": "",
+ })
+
+ if output:
+ print(output)
+ try:
+ print("IMAGE")
+ return Image.open(io.BytesIO(output))
+ except UnidentifiedImageError:
+ print("FAILED IMAGE")
+ return output
+
+ def get_settings_rows(self):
+ self.rows = []
+
+ self.api_row = Adw.PasswordEntryRow()
+ self.api_row.connect("apply", self.on_apply)
+ self.api_row.props.text = self.data.get('api_key') or ""
+ self.api_row.props.title = _("API Key")
+ self.api_row.set_show_apply_button(True)
+ self.api_row.add_suffix(self.how_to_get_a_token())
+ self.rows.append(self.api_row)
+
+ return self.rows
+
+ def on_apply(self, widget):
+ api_key = self.api_row.get_text()
+ self.data["api_key"] = api_key
+
+ def make_prompt(self, prompt, chat):
+ return prompt
\ No newline at end of file
diff --git a/src/providers/baseimage.py b/src/providers/baseimage.py
new file mode 100644
index 0000000..d7eefd9
--- /dev/null
+++ b/src/providers/baseimage.py
@@ -0,0 +1,10 @@
+from .base import BaseProvider, ProviderType
+
+import requests
+
+from gi.repository import Gtk, Adw, GLib
+
+
+class BaseImageProvider(BaseProvider):
+ provider_type = ProviderType.IMAGE
+
\ No newline at end of file
diff --git a/src/providers/meson.build b/src/providers/meson.build
index e178e58..4de5d35 100644
--- a/src/providers/meson.build
+++ b/src/providers/meson.build
@@ -4,6 +4,8 @@ providers_sources = [
'__init__.py',
'aihorde.py',
'base.py',
+ 'basehfimage.py',
+ 'baseimage.py',
'blenderbot.py',
'catgpt.py',
'dialogpt.py',
@@ -19,6 +21,7 @@ providers_sources = [
'provider_item.py',
'stablebeluga2.py',
'robertasquad2.py',
+ 'stablediffusion.py',
]
PY_INSTALLDIR.install_sources(providers_sources, subdir: providers_dir)
\ No newline at end of file
diff --git a/src/providers/stablediffusion.py b/src/providers/stablediffusion.py
new file mode 100644
index 0000000..7e79655
--- /dev/null
+++ b/src/providers/stablediffusion.py
@@ -0,0 +1,5 @@
+from .basehfimage import BaseHFImageProvider
+
+class StableDiffusionProvider(BaseHFImageProvider):
+ name = "Stable Diffusion"
+ provider = "stabilityai/stable-diffusion-2-1"
\ No newline at end of file
diff --git a/src/views/window.py b/src/views/window.py
index da50990..bb5a040 100644
--- a/src/views/window.py
+++ b/src/views/window.py
@@ -19,6 +19,8 @@
from datetime import datetime
import locale
+import io
+import base64
from gi.repository import Gtk, Gio, Adw, GLib
from babel.dates import format_date, format_datetime, format_time
@@ -392,9 +394,17 @@ class BavarderWindow(Adw.ApplicationWindow):
self.toast.dismiss()
if not response:
- response = _("Sorry, I don't know what to say.")
+ self.add_assistant_item(_("Sorry, I don't know what to say."))
+ else:
+ if isinstance(response, str):
+ self.add_assistant_item(response)
+ else:
+ buffered = io.BytesIO()
+ response.save(buffered, format="JPEG")
+ img_str = base64.b64encode(buffered.getvalue())
+
+ self.add_assistant_item(img_str.decode("utf-8"))
- self.add_assistant_item(response)
except AttributeError:
self.toast.dismiss()
self.add_assistant_item(_("Sorry, I don't know what to say."))
diff --git a/src/widgets/item.py b/src/widgets/item.py
index 9a7598a..2fa652c 100644
--- a/src/widgets/item.py
+++ b/src/widgets/item.py
@@ -1,6 +1,10 @@
from gi.repository import Gtk, Adw, Gio, GLib, Pango, GtkSource, Gdk
import re
+import io
+import base64
+
+from PIL import Image, UnidentifiedImageError
from bavarder.constants import app_id, rootdir
from bavarder.widgets.code_block import CodeBlock
@@ -72,63 +76,86 @@ class Item(Gtk.Box):
self.content_text = self.item["content"]
- self.convert_content_to_pango()
-
- result = ""
- is_code = False
- for line in self.content_markup:
- if isinstance(line, str):
- if "`" in line.strip():
- if is_code:
- is_code = False
- else:
- is_code = True
- continue
- if is_code or not isinstance(line, str):
- label = Gtk.Label()
- label.set_use_markup(True)
- label.set_wrap(True)
- label.set_xalign(0)
- label.set_wrap_mode(Pango.WrapMode.WORD)
- label.set_markup(result)
- label.set_justify(Gtk.Justification.LEFT)
- label.set_valign(Gtk.Align.START)
- label.set_hexpand(True)
- label.set_halign(Gtk.Align.START)
- self.content.append(label)
-
- if not isinstance(line, str):
- result = "\n".join(line)
- else:
- result = line.strip()
-
- self.content.append(CodeBlock(result))
- result = ""
- else:
- result += f"{line}\n"
-
- else:
- if not result.strip() == "`":
- label = Gtk.Label()
- label.set_use_markup(True)
- label.set_wrap(True)
- label.set_xalign(0)
- label.set_wrap_mode(Pango.WrapMode.WORD)
- label.set_markup(result)
- label.set_justify(Gtk.Justification.LEFT)
- label.set_valign(Gtk.Align.START)
- label.set_hexpand(True)
- label.set_halign(Gtk.Align.START)
- self.content.append(label)
-
- t = self.item["role"].lower()
-
self.parent = parent
self.settings = parent.settings
self.app = self.parent.get_application()
self.win = self.app.get_active_window()
+ try:
+ if not isinstance(self.content_text, Image.Image):
+ if isinstance(self.content_text, bytes):
+ image = Image.open(io.BytesIO(self.content_text))
+ else:
+ image = Image.open(io.BytesIO(base64.b64decode(self.content_text)))
+ else:
+ image = self.content_text
+ except Exception:
+ self.convert_content_to_pango()
+
+ result = ""
+ is_code = False
+ for line in self.content_markup:
+ if isinstance(line, str):
+ if "`" in line.strip():
+ if is_code:
+ is_code = False
+ else:
+ is_code = True
+ continue
+ if is_code or not isinstance(line, str):
+ label = Gtk.Label()
+ label.set_use_markup(True)
+ label.set_wrap(True)
+ label.set_xalign(0)
+ label.set_wrap_mode(Pango.WrapMode.WORD)
+ label.set_markup(result)
+ label.set_justify(Gtk.Justification.LEFT)
+ label.set_valign(Gtk.Align.START)
+ label.set_hexpand(True)
+ label.set_halign(Gtk.Align.START)
+ self.content.append(label)
+
+ if not isinstance(line, str):
+ result = "\n".join(line)
+ else:
+ result = line.strip()
+
+ self.content.append(CodeBlock(result))
+ result = ""
+ else:
+ result += f"{line}\n"
+
+ else:
+ if not result.strip() == "`":
+ label = Gtk.Label()
+ label.set_use_markup(True)
+ label.set_wrap(True)
+ label.set_xalign(0)
+ label.set_wrap_mode(Pango.WrapMode.WORD)
+ label.set_markup(result)
+ label.set_justify(Gtk.Justification.LEFT)
+ label.set_valign(Gtk.Align.START)
+ label.set_hexpand(True)
+ label.set_halign(Gtk.Align.START)
+ self.content.append(label)
+ else:
+ picture = Gtk.Picture()
+ picture.set_halign(Gtk.Align.CENTER)
+ picture.set_can_shrink(True)
+ picture.set_content_fit(Gtk.ContentFit.FILL)
+ picture.set_visible(True)
+ picture.add_css_class("card")
+ picture.set_margin_start(12)
+ picture.set_margin_end(12)
+ #print(self.content.get_width(), self.content.get_height())
+ picture.set_size_request(270, 270)
+ image.save("/tmp/image.png")
+ picture.set_file(Gio.File.new_for_path("/tmp/image.png"))
+ self.content.append(picture)
+
+ t = self.item["role"].lower()
+
if t == self.app.user_name.lower() or t == "user": # User
self.message_bubble.add_css_class("message-bubble-user")
self.avatar.add_css_class("avatar-user")