feat: add openai image

This commit is contained in:
0xmrtt 2024-02-25 13:48:14 +01:00
parent 14c3dbdcf1
commit b80b5998e4
5 changed files with 128 additions and 1 deletions

View File

@ -14,6 +14,8 @@ from .stablediffusion import StableDiffusionProvider
from .analogdiffusion import AnalogDiffusionProvider from .analogdiffusion import AnalogDiffusionProvider
from .nitrodiffusion import NitroDiffusionProvider from .nitrodiffusion import NitroDiffusionProvider
from .openjourney import OpenJourneyProvider from .openjourney import OpenJourneyProvider
from .openaiimage import DallE2, DallE3
from .portraitplus import PortraitPlusProvider
PROVIDERS = { PROVIDERS = {
AIHordeProvider, AIHordeProvider,
@ -29,6 +31,9 @@ PROVIDERS = {
AnalogDiffusionProvider, AnalogDiffusionProvider,
NitroDiffusionProvider, NitroDiffusionProvider,
OpenJourneyProvider, OpenJourneyProvider,
DallE2,
DallE3,
PortraitPlusProvider,
# StableBeluga2Provider, # StableBeluga2Provider,
# HuggingFaceOpenAssistantSFT1PythiaProvider, # HuggingFaceOpenAssistantSFT1PythiaProvider,
# RobertaSquad2Provider # RobertaSquad2Provider

View File

@ -3,4 +3,3 @@ from .basehfimage import BaseHFImageProvider
class AnalogDiffusionProvider(BaseHFImageProvider): class AnalogDiffusionProvider(BaseHFImageProvider):
name = "Analog Diffusion" name = "Analog Diffusion"
provider = "wavymulder/Analog-Diffusion" provider = "wavymulder/Analog-Diffusion"
3

View File

@ -18,9 +18,11 @@ providers_sources = [
'openai.py', 'openai.py',
'openaigpt35turbo.py', 'openaigpt35turbo.py',
'openaigpt4.py', 'openaigpt4.py',
'openaiimage.py',
'openassistantsft1pythia12b.py', 'openassistantsft1pythia12b.py',
'openjourney.py', 'openjourney.py',
'petals.py', 'petals.py',
'portraitplus.py',
'provider_item.py', 'provider_item.py',
'stablebeluga2.py', 'stablebeluga2.py',
'robertasquad2.py', 'robertasquad2.py',

View File

@ -0,0 +1,116 @@
from .baseimage import BaseImageProvider
import openai
from openai import OpenAI
import socket
import os
import json
from gi.repository import Gtk, Adw, GLib
class BaseOpenAIImageProvider(BaseProvider):
model = None
api_key_title = "API Key"
def __init__(self, app, window):
super().__init__(app, window)
try:
self.client = OpenAI(
api_key=os.environ.get("OPENAI_API_KEY"),
)
except openai.OpenAIError:
self.client = OpenAI(
api_key="",
)
if self.data.get("api_key"):
self.client.api_key = self.data["api_key"]
if self.data.get("api_base"):
self.client.base_url = self.data["api_base"]
def ask(self, prompt, chat):
if self.model:
prompt = self.chunk(prompt)
try:
response = client.images.generate(
model=self.model,
prompt=self.prompt,
size="1024x1024",
quality="standard",
n=1,
)
image_url = response.data[0].url
image_bytes = requests.get(image_url).content
except openai.AuthenticationError:
return _("Your API key is invalid, please check your preferences.")
except openai.BadRequestError:
return _("You don't have access to this model, please check your plan and billing details.")
except openai.RateLimitError:
return _("You exceeded your current quota, please check your plan and billing details.")
except openai.APIConnectionError:
return _("I'm having trouble connecting to the API, please check your internet connection.")
except socket.gaierror:
return _("I'm having trouble connecting to the API, please check your internet connection.")
else:
if image_bytes:
try:
return Image.open(io.BytesIO(image_bytes))
except UnidentifiedImageError:
error = json.loads(image_bytes)["error"]
return error
else:
return None
else:
return _("No model selected, you can choose one in preferences")
def get_settings_rows(self):
self.rows = []
self.api_row = Adw.PasswordEntryRow()
self.api_row.connect("apply", self.on_apply)
self.api_row.props.text = self.client.api_key or ""
self.api_row.props.title = self.api_key_title
self.api_row.set_show_apply_button(True)
self.api_row.add_suffix(self.how_to_get_a_token())
self.rows.append(self.api_row)
self.api_url_row = Adw.EntryRow()
self.api_url_row.connect("apply", self.on_apply)
self.api_url_row.props.text=str(self.client.base_url) or ""
self.api_url_row.props.title = "API Url"
self.api_url_row.set_show_apply_button(True)
self.api_url_row.add_suffix(self.how_to_get_base_url())
self.rows.append(self.api_url_row)
return self.rows
def on_apply(self, widget):
api_key = self.api_row.get_text()
self.client.api_key = api_key
self.client.base_url = self.api_url_row.get_text()
self.data["api_key"] = self.client.api_key
self.data["api_base"] = str(self.client.base_url)
def how_to_get_base_url(self):
about_button = Gtk.Button()
about_button.set_icon_name("dialog-information-symbolic")
about_button.set_tooltip_text("How to choose base url")
about_button.add_css_class("flat")
about_button.set_valign(Gtk.Align.CENTER)
about_button.connect("clicked", self.open_documentation)
return about_button
class DallE2(BaseOpenAIImageProvider):
name = "DALL·E 2"
model = "dall-e-2"
class DallE3(BaseOpenAIImageProvider):
name = "DALL·E 3"
model = "dall-e-3"

View File

@ -0,0 +1,5 @@
from .basehfimage import BaseHFImageProvider
class PortraitPlusProvider(BaseHFImageProvider):
name = "Portrait Plus"
model = "wavymulder/portraitplus"