Add image providers

This commit is contained in:
0xmrtt 2024-02-25 12:38:39 +01:00
parent 02b02edbbb
commit 7456becf9e
8 changed files with 176 additions and 54 deletions

View File

@ -55,6 +55,7 @@
lxml
openai
pygobject3
pillow
requests
];

View File

@ -10,6 +10,7 @@ from .openassistantsft1pythia12b import HuggingFaceOpenAssistantSFT1PythiaProvid
from .robertasquad2 import RobertaSquad2Provider
from .local import LocalProvider
from .aihorde import AIHordeProvider
from .stablediffusion import StableDiffusionProvider
PROVIDERS = {
AIHordeProvider,
@ -20,7 +21,8 @@ PROVIDERS = {
OpenAIGPT4Provider,
GoogleFlant5XXLProvider,
GPT2Provider,
LocalProvider
LocalProvider,
StableDiffusionProvider,
# StableBeluga2Provider,
# HuggingFaceOpenAssistantSFT1PythiaProvider,
# RobertaSquad2Provider

View File

@ -0,0 +1,64 @@
from .baseimage import BaseImageProvider
import requests
import json
from gi.repository import Gtk, Adw, GLib
from PIL import Image, UnidentifiedImageError
import io
class BaseHFImageProvider(BaseImageProvider):
    """Image provider that queries the Hugging Face hosted inference API.

    Subclasses set ``provider`` to the model identifier that is appended to
    the inference endpoint URL.
    """

    # Model identifier used to build the endpoint URL; set by subclasses.
    provider = None

    def ask(self, prompt, chat, **kwargs):
        """Request an image for *prompt* from the configured model.

        Args:
            prompt: Text sent as the ``inputs`` field of the request.
            chat: Mapping with a ``"content"`` entry, forwarded to
                ``make_prompt``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A PIL image when the response body decodes as an image, a
            translated error message (``str``) when the HTTP status is not
            200, the raw response bytes when they are not a recognizable
            image, or ``None`` when the response body is empty.
        """
        chat = chat["content"]
        api_url = f"https://api-inference.huggingface.co/models/{self.provider}"

        def query(payload):
            """POST *payload* and return the raw bytes or an error message."""
            api_key = self.data.get('api_key')
            # Only send the Authorization header when a token is configured.
            headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
            # NOTE(review): no timeout is set, so a stalled connection blocks
            # the caller indefinitely — consider adding one once callers
            # handle the resulting exception.
            response = requests.post(api_url, json=payload, headers=headers)

            if response.status_code == 403:
                return _("You've reached the rate limit! Please add a token to the preferences. You can get the token by following this [guide](https://bavarder.codeberg.page/help/huggingface/)")
            if response.status_code != 200:
                # The msgid is kept unchanged for translators; the status
                # code is substituted after translation (the original never
                # interpolated the placeholder).
                return _("Sorry, I don't know what to say! (Error: {response.status_code})").format(response=response)
            return response.content

        prompt = self.make_prompt(prompt, chat)
        output = query({
            "inputs": prompt,
            "negative_prompts": "",
        })

        if not output:
            return None
        if isinstance(output, str):
            # Error messages are text, not image data; io.BytesIO would raise
            # TypeError on them, so hand them back to the caller as-is.
            return output
        try:
            return Image.open(io.BytesIO(output))
        except UnidentifiedImageError:
            # Not a decodable image: return the raw payload unchanged.
            return output

    def get_settings_rows(self):
        """Build and return the preference rows (a single API key entry)."""
        self.rows = []

        self.api_row = Adw.PasswordEntryRow()
        self.api_row.connect("apply", self.on_apply)
        self.api_row.props.text = self.data.get('api_key') or ""
        self.api_row.props.title = _("API Key")
        self.api_row.set_show_apply_button(True)
        self.api_row.add_suffix(self.how_to_get_a_token())
        self.rows.append(self.api_row)

        return self.rows

    def on_apply(self, widget):
        """Store the text of the API key row in ``self.data["api_key"]``."""
        api_key = self.api_row.get_text()
        self.data["api_key"] = api_key

    def make_prompt(self, prompt, chat):
        """Return the prompt to send; this base version ignores *chat*."""
        return prompt

View File

@ -0,0 +1,10 @@
from .base import BaseProvider, ProviderType
import requests
from gi.repository import Gtk, Adw, GLib
class BaseImageProvider(BaseProvider):
    """Common base class for providers of type ``ProviderType.IMAGE``."""

    # Marks every subclass as an image provider.
    provider_type = ProviderType.IMAGE

View File

@ -4,6 +4,8 @@ providers_sources = [
'__init__.py',
'aihorde.py',
'base.py',
'basehfimage.py',
'baseimage.py',
'blenderbot.py',
'catgpt.py',
'dialogpt.py',
@ -19,6 +21,7 @@ providers_sources = [
'provider_item.py',
'stablebeluga2.py',
'robertasquad2.py',
'stablediffusion.py',
]
PY_INSTALLDIR.install_sources(providers_sources, subdir: providers_dir)

View File

@ -0,0 +1,5 @@
from .basehfimage import BaseHFImageProvider
class StableDiffusionProvider(BaseHFImageProvider):
    """Hugging Face image provider configured for Stable Diffusion 2.1."""

    # Display name of the provider.
    name = "Stable Diffusion"
    # Model identifier consumed by the BaseHFImageProvider base class.
    provider = "stabilityai/stable-diffusion-2-1"

View File

@ -19,6 +19,8 @@
from datetime import datetime
import locale
import io
import base64
from gi.repository import Gtk, Gio, Adw, GLib
from babel.dates import format_date, format_datetime, format_time
@ -392,9 +394,17 @@ class BavarderWindow(Adw.ApplicationWindow):
self.toast.dismiss()
if not response:
response = _("Sorry, I don't know what to say.")
self.add_assistant_item(_("Sorry, I don't know what to say."))
else:
if isinstance(response, str):
self.add_assistant_item(response)
else:
buffered = io.BytesIO()
response.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue())
self.add_assistant_item(img_str.decode("utf-8"))
self.add_assistant_item(response)
except AttributeError:
self.toast.dismiss()
self.add_assistant_item(_("Sorry, I don't know what to say."))

View File

@ -1,6 +1,10 @@
from gi.repository import Gtk, Adw, Gio, GLib, Pango, GtkSource, Gdk
import re
import io
import base64
from PIL import Image, UnidentifiedImageError
from bavarder.constants import app_id, rootdir
from bavarder.widgets.code_block import CodeBlock
@ -72,63 +76,86 @@ class Item(Gtk.Box):
self.content_text = self.item["content"]
self.convert_content_to_pango()
result = ""
is_code = False
for line in self.content_markup:
if isinstance(line, str):
if "<tt></tt>`" in line.strip():
if is_code:
is_code = False
else:
is_code = True
continue
if is_code or not isinstance(line, str):
label = Gtk.Label()
label.set_use_markup(True)
label.set_wrap(True)
label.set_xalign(0)
label.set_wrap_mode(Pango.WrapMode.WORD)
label.set_markup(result)
label.set_justify(Gtk.Justification.LEFT)
label.set_valign(Gtk.Align.START)
label.set_hexpand(True)
label.set_halign(Gtk.Align.START)
self.content.append(label)
if not isinstance(line, str):
result = "\n".join(line)
else:
result = line.strip()
self.content.append(CodeBlock(result))
result = ""
else:
result += f"{line}\n"
else:
if not result.strip() == "<tt></tt>`":
label = Gtk.Label()
label.set_use_markup(True)
label.set_wrap(True)
label.set_xalign(0)
label.set_wrap_mode(Pango.WrapMode.WORD)
label.set_markup(result)
label.set_justify(Gtk.Justification.LEFT)
label.set_valign(Gtk.Align.START)
label.set_hexpand(True)
label.set_halign(Gtk.Align.START)
self.content.append(label)
t = self.item["role"].lower()
self.parent = parent
self.settings = parent.settings
self.app = self.parent.get_application()
self.win = self.app.get_active_window()
try:
if not isinstance(self.content_text, Image.Image):
if isinstance(self.content_text, bytes):
image = Image.open(io.BytesIO(self.content_text))
else:
image = Image.open(io.BytesIO(base64.b64decode(self.content_text)))
else:
image = self.content_text
except Exception:
self.convert_content_to_pango()
result = ""
is_code = False
for line in self.content_markup:
if isinstance(line, str):
if "<tt></tt>`" in line.strip():
if is_code:
is_code = False
else:
is_code = True
continue
if is_code or not isinstance(line, str):
label = Gtk.Label()
label.set_use_markup(True)
label.set_wrap(True)
label.set_xalign(0)
label.set_wrap_mode(Pango.WrapMode.WORD)
label.set_markup(result)
label.set_justify(Gtk.Justification.LEFT)
label.set_valign(Gtk.Align.START)
label.set_hexpand(True)
label.set_halign(Gtk.Align.START)
self.content.append(label)
if not isinstance(line, str):
result = "\n".join(line)
else:
result = line.strip()
self.content.append(CodeBlock(result))
result = ""
else:
result += f"{line}\n"
else:
if not result.strip() == "<tt></tt>`":
label = Gtk.Label()
label.set_use_markup(True)
label.set_wrap(True)
label.set_xalign(0)
label.set_wrap_mode(Pango.WrapMode.WORD)
label.set_markup(result)
label.set_justify(Gtk.Justification.LEFT)
label.set_valign(Gtk.Align.START)
label.set_hexpand(True)
label.set_halign(Gtk.Align.START)
self.content.append(label)
else:
picture = Gtk.Picture()
picture.set_halign(Gtk.Align.CENTER)
picture.set_can_shrink(True)
picture.set_content_fit(Gtk.ContentFit.FILL)
picture.set_visible(True)
picture.add_css_class("card")
picture.set_margin_start(12)
picture.set_margin_end(12)
#print(self.content.get_width(), self.content.get_height())
picture.set_size_request(270, 270)
image.save("/tmp/image.png")
picture.set_file(Gio.File.new_for_path("/tmp/image.png"))
self.content.append(picture)
t = self.item["role"].lower()
if t == self.app.user_name.lower() or t == "user": # User
self.message_bubble.add_css_class("message-bubble-user")
self.avatar.add_css_class("avatar-user")