diff --git a/src/provider/huggingchat.py b/src/provider/huggingchat.py index e8e58b5..8d681b8 100644 --- a/src/provider/huggingchat.py +++ b/src/provider/huggingchat.py @@ -1,49 +1,6 @@ -from .base import BavarderProvider +from .huggingchatbase import BaseHuggingChatProvider -from hgchat import HGChat -import socket - - -from gi.repository import Gtk, Adw, GLib - - -class HuggingChatProvider(BavarderProvider): +class HuggingChatProvider(BaseHuggingChatProvider): name = "Hugging Chat" slug = "huggingchat" - - def __init__(self, win, app, *args, **kwargs): - super().__init__(win, app, *args, **kwargs) - self.chat = HGChat() - - def ask(self, prompt): - try: - response = self.chat.ask(prompt) - except socket.gaierror: - self.no_connection() - return "" - except Exception as e: - self.win.banner.props.title = str(e) - self.win.banner.props.button_label = "" - self.win.banner.set_revealed(True) - return "" - else: - self.win.banner.set_revealed(False) - r = "" - for i in response: - char = i["token"]["text"] - if char == "": - r += "\n" - else: - r += char - GLib.idle_add(self.update_response, r) - return r - - @property - def require_api_key(self): - return False - - def save(self): - return {} - - def load(self, data): - pass + model = "OpenAssistant/oasst-sft-6-llama-30b-xor" diff --git a/src/provider/huggingchatbase.py b/src/provider/huggingchatbase.py new file mode 100644 index 0000000..2d76e93 --- /dev/null +++ b/src/provider/huggingchatbase.py @@ -0,0 +1,50 @@ +from .base import BavarderProvider + +from hgchat import HGChat +import socket + + +from gi.repository import Gtk, Adw, GLib + + +class BaseHuggingChatProvider(BavarderProvider): + name = "Hugging Chat" + slug = "huggingchat" + model = None + + def __init__(self, win, app, *args, **kwargs): + super().__init__(win, app, *args, **kwargs) + self.chat = HGChat(self.model) + + def ask(self, prompt): + try: + response = self.chat.ask(prompt) + except socket.gaierror: + self.no_connection() + return "" + except Exception as e: + self.win.banner.props.title = str(e) + self.win.banner.props.button_label = "" + self.win.banner.set_revealed(True) + return "" + else: + self.win.banner.set_revealed(False) + r = "" + for i in response: + char = i["token"]["text"] + if char == "": + r += "\n" + else: + r += char + GLib.idle_add(self.update_response, r) + return r + + @property + def require_api_key(self): + return False + + def save(self): + return {} + + def load(self, data): + pass