diff --git a/src/main.py b/src/main.py index e1e9e41..8be65b9 100644 --- a/src/main.py +++ b/src/main.py @@ -117,6 +117,7 @@ class ImaginerApplication(Adw.Application): GLib.Variant("s", self.latest_provider), self.on_set_provider_action ) + self.providers = {} def quitting(self, *args, **kwargs): @@ -187,30 +188,31 @@ class ImaginerApplication(Adw.Application): self.menu_model.append_item(Gio.MenuItem.new(label=_("New Window"), detailed_action="app.new")) section_menu = Gio.Menu() - provider_menu = Gio.Menu() - - self.providers = {} self.providers_data = self.settings.get_value("providers-data") for provider in self.enabled_providers: - try: - item = PROVIDERS[provider] - item_model = Gio.MenuItem() - item_model.set_label(item.name) - item_model.set_action_and_target_value( - "app.set_provider", - GLib.Variant("s", item.slug)) - provider_menu.append_item(item_model) - except KeyError: - print("Provider", provider, "not found") - continue + if provider in self.providers: + p = self.providers[provider] + name = p.name + slug = p.slug else: try: - self.providers[item.slug] # doesn't re load if already loaded + p = PROVIDERS[provider] + name = p.name + slug = p.slug except KeyError: - self.providers[item.slug] = PROVIDERS[provider](window, self) + continue + else: + self.providers[slug] = PROVIDERS[provider](window, self) + + item_model = Gio.MenuItem() + item_model.set_label(name) + item_model.set_action_and_target_value( + "app.set_provider", + GLib.Variant("s", slug)) + provider_menu.append_item(item_model) section_menu.append_submenu(_("Providers"), provider_menu) diff --git a/src/preferences.py b/src/preferences.py index 490ce2b..c824203 100644 --- a/src/preferences.py +++ b/src/preferences.py @@ -17,19 +17,13 @@ class Preferences(Adw.PreferencesWindow): self.setup_providers() def setup_providers(self): - # for provider in self.app.providers.values(): - # try: - # self.provider_group.add(provider.preferences(self)) - # except TypeError: # no prefs - # pass - # else: - # row = Adw.ActionRow() - # row.props.title = "No providers available" - # self.provider_group.add(row) for provider in PROVIDERS.values(): - try: + if provider.slug in self.app.providers: self.provider_group.add( - provider(self.app.win, self.app).preferences(self) + self.app.providers[provider.slug].preferences(win=self.app.win) ) - except TypeError: - pass + else: + self.provider_group.add( + provider(self.app.win, self.app).preferences(win=self.app.win) + ) + diff --git a/src/provider/huggingface.py b/src/provider/huggingface.py index a0a9600..ed8355c 100644 --- a/src/provider/huggingface.py +++ b/src/provider/huggingface.py @@ -19,81 +19,75 @@ class BaseHFProvider(ImaginerProvider): self.api_key = None def ask(self, prompt, negative_prompt): - try: - payload = json.dumps( - { - "inputs": prompt, - "negative_prompts": negative_prompt if negative_prompt else "", - } - ) - headers = {"Content-Type": "application/json"} - if self.require_api_key and self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - url = f"https://api-inference.huggingface.co/models/{self.model}" - response = requests.request("POST", url, headers=headers, data=payload) - if response.status_code == 403: - self.no_api_key() - return "" - elif response.status_code != 200: - self.win.banner.props.title = response.json()["error"] - self.win.banner.props.button_label = "" - self.win.banner.set_revealed(True) - return "" - response = response.content - except KeyError: - print("KeyError") - pass - except socket.gaierror: - self.no_connection() - return "" - else: - self.hide_banner() - if response: - try: - return Image.open(io.BytesIO(response)) - except UnidentifiedImageError: - error = json.loads(response)["error"] - self.win.banner.set_title(error) + if self.model: + try: + payload = json.dumps( + { + "inputs": prompt, + "negative_prompts": negative_prompt if negative_prompt else "", + } + ) + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + url = f"https://api-inference.huggingface.co/models/{self.model}" + response = requests.request("POST", url, headers=headers, data=payload) + if response.status_code == 403: + self.no_api_key() + return "" + elif response.status_code != 200: + self.win.banner.props.title = response.json()["error"] + self.win.banner.props.button_label = "" self.win.banner.set_revealed(True) - return None + return "" + response = response.content + except KeyError: + pass + except socket.gaierror: + self.no_connection() + return "" else: - print("No response") - return None - - @property - def require_api_key(self): - return True + self.hide_banner() + if response: + try: + return Image.open(io.BytesIO(response)) + except UnidentifiedImageError: + error = json.loads(response)["error"] + self.win.banner.set_title(error) + self.win.banner.set_revealed(True) + return None + else: + print("No response") + return None + else: + self.no_api_key(title="No model selected, you can choose one in preferences") + return "" def preferences(self, win): - if self.require_api_key: - self.expander = Adw.ExpanderRow() - self.expander.props.title = self.name + self.expander = Adw.ExpanderRow() + self.expander.props.title = self.name - self.expander.add_action(self.about()) - self.expander.add_action(self.enable_switch()) + self.expander.add_action(self.about()) + self.expander.add_action(self.enable_switch()) - self.api_row = Adw.PasswordEntryRow() - self.api_row.connect("apply", self.on_apply) - self.api_row.props.title = "API Key" - self.api_row.props.text = self.api_key or "" - self.api_row.add_suffix(self.how_to_get_a_token()) - self.api_row.set_show_apply_button(True) - self.expander.add_row(self.api_row) + self.api_row = Adw.PasswordEntryRow() + self.api_row.connect("apply", self.on_apply) + self.api_row.props.title = "API Key" + self.api_row.props.text = self.api_key or "" + self.api_row.add_suffix(self.how_to_get_a_token()) + self.api_row.set_show_apply_button(True) + self.expander.add_row(self.api_row) - return self.expander - else: - return self.no_preferences(win) + return self.expander def on_apply(self, widget): self.hide_banner() self.api_key = self.api_row.get_text() - print(self.api_key) + self.app.save_providers() def save(self): - if self.require_api_key: - return {"api_key": self.api_key} - return {} + return {"api_key": self.api_key} def load(self, data): - if self.require_api_key: + if data["api_key"]: self.api_key = data["api_key"]