fix: rework provider loading mechanism and fix bug
This commit is contained in:
parent
5f8a712c5a
commit
f818bd4b50
34
src/main.py
34
src/main.py
|
@ -117,6 +117,7 @@ class ImaginerApplication(Adw.Application):
|
||||||
GLib.Variant("s", self.latest_provider),
|
GLib.Variant("s", self.latest_provider),
|
||||||
self.on_set_provider_action
|
self.on_set_provider_action
|
||||||
)
|
)
|
||||||
|
self.providers = {}
|
||||||
|
|
||||||
|
|
||||||
def quitting(self, *args, **kwargs):
|
def quitting(self, *args, **kwargs):
|
||||||
|
@ -187,30 +188,31 @@ class ImaginerApplication(Adw.Application):
|
||||||
self.menu_model.append_item(Gio.MenuItem.new(label=_("New Window"), detailed_action="app.new"))
|
self.menu_model.append_item(Gio.MenuItem.new(label=_("New Window"), detailed_action="app.new"))
|
||||||
|
|
||||||
section_menu = Gio.Menu()
|
section_menu = Gio.Menu()
|
||||||
|
|
||||||
provider_menu = Gio.Menu()
|
provider_menu = Gio.Menu()
|
||||||
|
|
||||||
|
|
||||||
self.providers = {}
|
|
||||||
self.providers_data = self.settings.get_value("providers-data")
|
self.providers_data = self.settings.get_value("providers-data")
|
||||||
|
|
||||||
for provider in self.enabled_providers:
|
for provider in self.enabled_providers:
|
||||||
try:
|
if provider in self.providers:
|
||||||
item = PROVIDERS[provider]
|
p = self.providers[provider]
|
||||||
item_model = Gio.MenuItem()
|
name = p.name
|
||||||
item_model.set_label(item.name)
|
slug = p.slug
|
||||||
item_model.set_action_and_target_value(
|
|
||||||
"app.set_provider",
|
|
||||||
GLib.Variant("s", item.slug))
|
|
||||||
provider_menu.append_item(item_model)
|
|
||||||
except KeyError:
|
|
||||||
print("Provider", provider, "not found")
|
|
||||||
continue
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
self.providers[item.slug] # doesn't re load if already loaded
|
p = PROVIDERS[provider]
|
||||||
|
name = p.name
|
||||||
|
slug = p.slug
|
||||||
except KeyError:
|
except KeyError:
|
||||||
self.providers[item.slug] = PROVIDERS[provider](window, self)
|
continue
|
||||||
|
else:
|
||||||
|
self.providers[slug] = PROVIDERS[provider](window, self)
|
||||||
|
|
||||||
|
item_model = Gio.MenuItem()
|
||||||
|
item_model.set_label(name)
|
||||||
|
item_model.set_action_and_target_value(
|
||||||
|
"app.set_provider",
|
||||||
|
GLib.Variant("s", slug))
|
||||||
|
provider_menu.append_item(item_model)
|
||||||
|
|
||||||
section_menu.append_submenu(_("Providers"), provider_menu)
|
section_menu.append_submenu(_("Providers"), provider_menu)
|
||||||
|
|
||||||
|
|
|
@ -17,19 +17,13 @@ class Preferences(Adw.PreferencesWindow):
|
||||||
self.setup_providers()
|
self.setup_providers()
|
||||||
|
|
||||||
def setup_providers(self):
|
def setup_providers(self):
|
||||||
# for provider in self.app.providers.values():
|
|
||||||
# try:
|
|
||||||
# self.provider_group.add(provider.preferences(self))
|
|
||||||
# except TypeError: # no prefs
|
|
||||||
# pass
|
|
||||||
# else:
|
|
||||||
# row = Adw.ActionRow()
|
|
||||||
# row.props.title = "No providers available"
|
|
||||||
# self.provider_group.add(row)
|
|
||||||
for provider in PROVIDERS.values():
|
for provider in PROVIDERS.values():
|
||||||
try:
|
if provider.slug in self.app.providers:
|
||||||
self.provider_group.add(
|
self.provider_group.add(
|
||||||
provider(self.app.win, self.app).preferences(self)
|
self.app.providers[provider.slug].preferences(win=self.app.win)
|
||||||
)
|
)
|
||||||
except TypeError:
|
else:
|
||||||
pass
|
self.provider_group.add(
|
||||||
|
provider(self.app.win, self.app).preferences(win=self.app.win)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
|
@ -19,81 +19,75 @@ class BaseHFProvider(ImaginerProvider):
|
||||||
self.api_key = None
|
self.api_key = None
|
||||||
|
|
||||||
def ask(self, prompt, negative_prompt):
|
def ask(self, prompt, negative_prompt):
|
||||||
try:
|
if self.model:
|
||||||
payload = json.dumps(
|
try:
|
||||||
{
|
payload = json.dumps(
|
||||||
"inputs": prompt,
|
{
|
||||||
"negative_prompts": negative_prompt if negative_prompt else "",
|
"inputs": prompt,
|
||||||
}
|
"negative_prompts": negative_prompt if negative_prompt else "",
|
||||||
)
|
}
|
||||||
headers = {"Content-Type": "application/json"}
|
)
|
||||||
if self.require_api_key and self.api_key:
|
headers = {"Content-Type": "application/json"}
|
||||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
if self.api_key:
|
||||||
url = f"https://api-inference.huggingface.co/models/{self.model}"
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||||
response = requests.request("POST", url, headers=headers, data=payload)
|
url = f"https://api-inference.huggingface.co/models/{self.model}"
|
||||||
if response.status_code == 403:
|
response = requests.request("POST", url, headers=headers, data=payload)
|
||||||
self.no_api_key()
|
if response.status_code == 403:
|
||||||
return ""
|
self.no_api_key()
|
||||||
elif response.status_code != 200:
|
return ""
|
||||||
self.win.banner.props.title = response.json()["error"]
|
elif response.status_code != 200:
|
||||||
self.win.banner.props.button_label = ""
|
self.win.banner.props.title = response.json()["error"]
|
||||||
self.win.banner.set_revealed(True)
|
self.win.banner.props.button_label = ""
|
||||||
return ""
|
|
||||||
response = response.content
|
|
||||||
except KeyError:
|
|
||||||
print("KeyError")
|
|
||||||
pass
|
|
||||||
except socket.gaierror:
|
|
||||||
self.no_connection()
|
|
||||||
return ""
|
|
||||||
else:
|
|
||||||
self.hide_banner()
|
|
||||||
if response:
|
|
||||||
try:
|
|
||||||
return Image.open(io.BytesIO(response))
|
|
||||||
except UnidentifiedImageError:
|
|
||||||
error = json.loads(response)["error"]
|
|
||||||
self.win.banner.set_title(error)
|
|
||||||
self.win.banner.set_revealed(True)
|
self.win.banner.set_revealed(True)
|
||||||
return None
|
return ""
|
||||||
|
response = response.content
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
except socket.gaierror:
|
||||||
|
self.no_connection()
|
||||||
|
return ""
|
||||||
else:
|
else:
|
||||||
print("No response")
|
self.hide_banner()
|
||||||
return None
|
if response:
|
||||||
|
try:
|
||||||
@property
|
return Image.open(io.BytesIO(response))
|
||||||
def require_api_key(self):
|
except UnidentifiedImageError:
|
||||||
return True
|
error = json.loads(response)["error"]
|
||||||
|
self.win.banner.set_title(error)
|
||||||
|
self.win.banner.set_revealed(True)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
print("No response")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
self.no_api_key(title="No model selected, you can choose one in preferences")
|
||||||
|
return ""
|
||||||
|
|
||||||
def preferences(self, win):
|
def preferences(self, win):
|
||||||
if self.require_api_key:
|
self.expander = Adw.ExpanderRow()
|
||||||
self.expander = Adw.ExpanderRow()
|
self.expander.props.title = self.name
|
||||||
self.expander.props.title = self.name
|
|
||||||
|
|
||||||
self.expander.add_action(self.about())
|
self.expander.add_action(self.about())
|
||||||
self.expander.add_action(self.enable_switch())
|
self.expander.add_action(self.enable_switch())
|
||||||
|
|
||||||
self.api_row = Adw.PasswordEntryRow()
|
self.api_row = Adw.PasswordEntryRow()
|
||||||
self.api_row.connect("apply", self.on_apply)
|
self.api_row.connect("apply", self.on_apply)
|
||||||
self.api_row.props.title = "API Key"
|
self.api_row.props.title = "API Key"
|
||||||
self.api_row.props.text = self.api_key or ""
|
self.api_row.props.text = self.api_key or ""
|
||||||
self.api_row.add_suffix(self.how_to_get_a_token())
|
self.api_row.add_suffix(self.how_to_get_a_token())
|
||||||
self.api_row.set_show_apply_button(True)
|
self.api_row.set_show_apply_button(True)
|
||||||
self.expander.add_row(self.api_row)
|
self.expander.add_row(self.api_row)
|
||||||
|
|
||||||
return self.expander
|
return self.expander
|
||||||
else:
|
|
||||||
return self.no_preferences(win)
|
|
||||||
|
|
||||||
def on_apply(self, widget):
|
def on_apply(self, widget):
|
||||||
self.hide_banner()
|
self.hide_banner()
|
||||||
self.api_key = self.api_row.get_text()
|
self.api_key = self.api_row.get_text()
|
||||||
print(self.api_key)
|
self.app.save_providers()
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
if self.require_api_key:
|
return {"api_key": self.api_key}
|
||||||
return {"api_key": self.api_key}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load(self, data):
|
def load(self, data):
|
||||||
if self.require_api_key:
|
if data["api_key"]:
|
||||||
self.api_key = data["api_key"]
|
self.api_key = data["api_key"]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user