Add image providers
This commit is contained in:
parent
02b02edbbb
commit
7456becf9e
|
@ -10,6 +10,7 @@ from .openassistantsft1pythia12b import HuggingFaceOpenAssistantSFT1PythiaProvid
|
|||
from .robertasquad2 import RobertaSquad2Provider
|
||||
from .local import LocalProvider
|
||||
from .aihorde import AIHordeProvider
|
||||
from .stablediffusion import StableDiffusionProvider
|
||||
|
||||
PROVIDERS = {
|
||||
AIHordeProvider,
|
||||
|
@ -20,7 +21,8 @@ PROVIDERS = {
|
|||
OpenAIGPT4Provider,
|
||||
GoogleFlant5XXLProvider,
|
||||
GPT2Provider,
|
||||
LocalProvider
|
||||
LocalProvider,
|
||||
StableDiffusionProvider,
|
||||
# StableBeluga2Provider,
|
||||
# HuggingFaceOpenAssistantSFT1PythiaProvider,
|
||||
# RobertaSquad2Provider
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
from .baseimage import BaseImageProvider
|
||||
import requests
|
||||
import json
|
||||
from gi.repository import Gtk, Adw, GLib
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
import io
|
||||
|
||||
|
||||
class BaseHFImageProvider(BaseImageProvider):
|
||||
provider = None
|
||||
|
||||
def ask(self, prompt, chat, **kwargs):
|
||||
chat = chat["content"]
|
||||
|
||||
API_URL = f"https://api-inference.huggingface.co/models/{self.provider}"
|
||||
|
||||
def query(payload):
|
||||
if self.data.get('api_key'):
|
||||
headers = {"Authorization": f"Bearer {self.data['api_key']}"}
|
||||
response = requests.post(API_URL, json=payload, headers=headers)
|
||||
else:
|
||||
response = requests.post(API_URL, json=payload)
|
||||
|
||||
if response.status_code == 403:
|
||||
return _("You've reached the rate limit! Please add a token to the preferences. You can get the token by following this [guide](https://bavarder.codeberg.page/help/huggingface/)")
|
||||
elif response.status_code != 200:
|
||||
return _("Sorry, I don't know what to say! (Error: {response.status_code})")
|
||||
|
||||
return response.content
|
||||
|
||||
prompt = self.make_prompt(prompt, chat)
|
||||
output = query({
|
||||
"inputs": prompt,
|
||||
"negative_prompts": "",
|
||||
})
|
||||
|
||||
if output:
|
||||
print(output)
|
||||
try:
|
||||
print("IMAGE")
|
||||
return Image.open(io.BytesIO(output))
|
||||
except UnidentifiedImageError:
|
||||
print("FAILED IMAGE")
|
||||
return output
|
||||
|
||||
def get_settings_rows(self):
|
||||
self.rows = []
|
||||
|
||||
self.api_row = Adw.PasswordEntryRow()
|
||||
self.api_row.connect("apply", self.on_apply)
|
||||
self.api_row.props.text = self.data.get('api_key') or ""
|
||||
self.api_row.props.title = _("API Key")
|
||||
self.api_row.set_show_apply_button(True)
|
||||
self.api_row.add_suffix(self.how_to_get_a_token())
|
||||
self.rows.append(self.api_row)
|
||||
|
||||
return self.rows
|
||||
|
||||
def on_apply(self, widget):
|
||||
api_key = self.api_row.get_text()
|
||||
self.data["api_key"] = api_key
|
||||
|
||||
def make_prompt(self, prompt, chat):
|
||||
return prompt
|
|
@ -0,0 +1,10 @@
|
|||
from .base import BaseProvider, ProviderType
|
||||
|
||||
import requests
|
||||
|
||||
from gi.repository import Gtk, Adw, GLib
|
||||
|
||||
|
||||
class BaseImageProvider(BaseProvider):
|
||||
provider_type = ProviderType.IMAGE
|
||||
|
|
@ -4,6 +4,8 @@ providers_sources = [
|
|||
'__init__.py',
|
||||
'aihorde.py',
|
||||
'base.py',
|
||||
'basehfimage.py',
|
||||
'baseimage.py',
|
||||
'blenderbot.py',
|
||||
'catgpt.py',
|
||||
'dialogpt.py',
|
||||
|
@ -19,6 +21,7 @@ providers_sources = [
|
|||
'provider_item.py',
|
||||
'stablebeluga2.py',
|
||||
'robertasquad2.py',
|
||||
'stablediffusion.py',
|
||||
]
|
||||
|
||||
PY_INSTALLDIR.install_sources(providers_sources, subdir: providers_dir)
|
|
@ -0,0 +1,5 @@
|
|||
from .basehfimage import BaseHFImageProvider
|
||||
|
||||
class StableDiffusionProvider(BaseHFImageProvider):
|
||||
name = "Stable Diffusion"
|
||||
provider = "stabilityai/stable-diffusion-2-1"
|
|
@ -19,6 +19,8 @@
|
|||
|
||||
from datetime import datetime
|
||||
import locale
|
||||
import io
|
||||
import base64
|
||||
|
||||
from gi.repository import Gtk, Gio, Adw, GLib
|
||||
from babel.dates import format_date, format_datetime, format_time
|
||||
|
@ -392,9 +394,17 @@ class BavarderWindow(Adw.ApplicationWindow):
|
|||
self.toast.dismiss()
|
||||
|
||||
if not response:
|
||||
response = _("Sorry, I don't know what to say.")
|
||||
self.add_assistant_item(_("Sorry, I don't know what to say."))
|
||||
else:
|
||||
if isinstance(response, str):
|
||||
self.add_assistant_item(response)
|
||||
else:
|
||||
buffered = io.BytesIO()
|
||||
response.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue())
|
||||
|
||||
self.add_assistant_item(img_str.decode("utf-8"))
|
||||
|
||||
self.add_assistant_item(response)
|
||||
except AttributeError:
|
||||
self.toast.dismiss()
|
||||
self.add_assistant_item(_("Sorry, I don't know what to say."))
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
from gi.repository import Gtk, Adw, Gio, GLib, Pango, GtkSource, Gdk
|
||||
|
||||
import re
|
||||
import io
|
||||
import base64
|
||||
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
|
||||
from bavarder.constants import app_id, rootdir
|
||||
from bavarder.widgets.code_block import CodeBlock
|
||||
|
@ -72,63 +76,86 @@ class Item(Gtk.Box):
|
|||
|
||||
self.content_text = self.item["content"]
|
||||
|
||||
self.convert_content_to_pango()
|
||||
|
||||
result = ""
|
||||
is_code = False
|
||||
for line in self.content_markup:
|
||||
if isinstance(line, str):
|
||||
if "<tt></tt>`" in line.strip():
|
||||
if is_code:
|
||||
is_code = False
|
||||
else:
|
||||
is_code = True
|
||||
continue
|
||||
if is_code or not isinstance(line, str):
|
||||
label = Gtk.Label()
|
||||
label.set_use_markup(True)
|
||||
label.set_wrap(True)
|
||||
label.set_xalign(0)
|
||||
label.set_wrap_mode(Pango.WrapMode.WORD)
|
||||
label.set_markup(result)
|
||||
label.set_justify(Gtk.Justification.LEFT)
|
||||
label.set_valign(Gtk.Align.START)
|
||||
label.set_hexpand(True)
|
||||
label.set_halign(Gtk.Align.START)
|
||||
self.content.append(label)
|
||||
|
||||
if not isinstance(line, str):
|
||||
result = "\n".join(line)
|
||||
else:
|
||||
result = line.strip()
|
||||
|
||||
self.content.append(CodeBlock(result))
|
||||
result = ""
|
||||
else:
|
||||
result += f"{line}\n"
|
||||
|
||||
else:
|
||||
if not result.strip() == "<tt></tt>`":
|
||||
label = Gtk.Label()
|
||||
label.set_use_markup(True)
|
||||
label.set_wrap(True)
|
||||
label.set_xalign(0)
|
||||
label.set_wrap_mode(Pango.WrapMode.WORD)
|
||||
label.set_markup(result)
|
||||
label.set_justify(Gtk.Justification.LEFT)
|
||||
label.set_valign(Gtk.Align.START)
|
||||
label.set_hexpand(True)
|
||||
label.set_halign(Gtk.Align.START)
|
||||
self.content.append(label)
|
||||
|
||||
t = self.item["role"].lower()
|
||||
|
||||
self.parent = parent
|
||||
self.settings = parent.settings
|
||||
|
||||
self.app = self.parent.get_application()
|
||||
self.win = self.app.get_active_window()
|
||||
|
||||
try:
|
||||
if not isinstance(self.content_text, Image.Image):
|
||||
if isinstance(self.content_text, bytes):
|
||||
image = Image.open(io.BytesIO(self.content_text))
|
||||
else:
|
||||
image = Image.open(io.BytesIO(base64.b64decode(self.content_text)))
|
||||
else:
|
||||
image = self.content_text
|
||||
except Exception:
|
||||
self.convert_content_to_pango()
|
||||
|
||||
result = ""
|
||||
is_code = False
|
||||
for line in self.content_markup:
|
||||
if isinstance(line, str):
|
||||
if "<tt></tt>`" in line.strip():
|
||||
if is_code:
|
||||
is_code = False
|
||||
else:
|
||||
is_code = True
|
||||
continue
|
||||
if is_code or not isinstance(line, str):
|
||||
label = Gtk.Label()
|
||||
label.set_use_markup(True)
|
||||
label.set_wrap(True)
|
||||
label.set_xalign(0)
|
||||
label.set_wrap_mode(Pango.WrapMode.WORD)
|
||||
label.set_markup(result)
|
||||
label.set_justify(Gtk.Justification.LEFT)
|
||||
label.set_valign(Gtk.Align.START)
|
||||
label.set_hexpand(True)
|
||||
label.set_halign(Gtk.Align.START)
|
||||
self.content.append(label)
|
||||
|
||||
if not isinstance(line, str):
|
||||
result = "\n".join(line)
|
||||
else:
|
||||
result = line.strip()
|
||||
|
||||
self.content.append(CodeBlock(result))
|
||||
result = ""
|
||||
else:
|
||||
result += f"{line}\n"
|
||||
|
||||
else:
|
||||
if not result.strip() == "<tt></tt>`":
|
||||
label = Gtk.Label()
|
||||
label.set_use_markup(True)
|
||||
label.set_wrap(True)
|
||||
label.set_xalign(0)
|
||||
label.set_wrap_mode(Pango.WrapMode.WORD)
|
||||
label.set_markup(result)
|
||||
label.set_justify(Gtk.Justification.LEFT)
|
||||
label.set_valign(Gtk.Align.START)
|
||||
label.set_hexpand(True)
|
||||
label.set_halign(Gtk.Align.START)
|
||||
self.content.append(label)
|
||||
else:
|
||||
picture = Gtk.Picture()
|
||||
picture.set_halign(Gtk.Align.CENTER)
|
||||
picture.set_can_shrink(True)
|
||||
picture.set_content_fit(Gtk.ContentFit.FILL)
|
||||
picture.set_visible(True)
|
||||
picture.add_css_class("card")
|
||||
picture.set_margin_start(12)
|
||||
picture.set_margin_end(12)
|
||||
#print(self.content.get_width(), self.content.get_height())
|
||||
picture.set_size_request(270, 270)
|
||||
image.save("/tmp/image.png")
|
||||
picture.set_file(Gio.File.new_for_path("/tmp/image.png"))
|
||||
self.content.append(picture)
|
||||
|
||||
t = self.item["role"].lower()
|
||||
|
||||
if t == self.app.user_name.lower() or t == "user": # User
|
||||
self.message_bubble.add_css_class("message-bubble-user")
|
||||
self.avatar.add_css_class("avatar-user")
|
||||
|
|
Loading…
Reference in New Issue