provider(hgchat): refactor for adding new models
This commit is contained in:
parent
4d4bcf5957
commit
b559289c3d
|
@ -1,49 +1,6 @@
|
|||
from .base import BavarderProvider
|
||||
from .huggingchatbase import BaseHuggingChatProvider
|
||||
|
||||
from hgchat import HGChat
|
||||
import socket
|
||||
|
||||
|
||||
from gi.repository import Gtk, Adw, GLib
|
||||
|
||||
|
||||
class HuggingChatProvider(BavarderProvider):
|
||||
class HuggingChatProvider(BaseHuggingChatProvider):
|
||||
name = "Hugging Chat"
|
||||
slug = "huggingchat"
|
||||
|
||||
def __init__(self, win, app, *args, **kwargs):
|
||||
super().__init__(win, app, *args, **kwargs)
|
||||
self.chat = HGChat()
|
||||
|
||||
def ask(self, prompt):
|
||||
try:
|
||||
response = self.chat.ask(prompt)
|
||||
except socket.gaierror:
|
||||
self.no_connection()
|
||||
return ""
|
||||
except Exception as e:
|
||||
self.win.banner.props.title = str(e)
|
||||
self.win.banner.props.button_label = ""
|
||||
self.win.banner.set_revealed(True)
|
||||
return ""
|
||||
else:
|
||||
self.win.banner.set_revealed(False)
|
||||
r = ""
|
||||
for i in response:
|
||||
char = i["token"]["text"]
|
||||
if char == "</s>":
|
||||
r += "\n"
|
||||
else:
|
||||
r += char
|
||||
GLib.idle_add(self.update_response, r)
|
||||
return r
|
||||
|
||||
@property
|
||||
def require_api_key(self):
|
||||
return False
|
||||
|
||||
def save(self):
|
||||
return {}
|
||||
|
||||
def load(self, data):
|
||||
pass
|
||||
model = "OpenAssistant/oasst-sft-6-llama-30b-xor"
|
||||
|
|
50
src/provider/huggingchatbase.py
Normal file
50
src/provider/huggingchatbase.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
from .base import BavarderProvider
|
||||
|
||||
from hgchat import HGChat
|
||||
import socket
|
||||
|
||||
|
||||
from gi.repository import Gtk, Adw, GLib
|
||||
|
||||
|
||||
class BaseHuggingChatProvider(BavarderProvider):
|
||||
name = "Hugging Chat"
|
||||
slug = "huggingchat"
|
||||
model = None
|
||||
|
||||
def __init__(self, win, app, *args, **kwargs):
|
||||
super().__init__(win, app, *args, **kwargs)
|
||||
self.chat = HGChat(self.model)
|
||||
|
||||
def ask(self, prompt):
|
||||
try:
|
||||
response = self.chat.ask(prompt)
|
||||
except socket.gaierror:
|
||||
self.no_connection()
|
||||
return ""
|
||||
except Exception as e:
|
||||
self.win.banner.props.title = str(e)
|
||||
self.win.banner.props.button_label = ""
|
||||
self.win.banner.set_revealed(True)
|
||||
return ""
|
||||
else:
|
||||
self.win.banner.set_revealed(False)
|
||||
r = ""
|
||||
for i in response:
|
||||
char = i["token"]["text"]
|
||||
if char == "</s>":
|
||||
r += "\n"
|
||||
else:
|
||||
r += char
|
||||
GLib.idle_add(self.update_response, r)
|
||||
return r
|
||||
|
||||
@property
|
||||
def require_api_key(self):
|
||||
return False
|
||||
|
||||
def save(self):
|
||||
return {}
|
||||
|
||||
def load(self, data):
|
||||
pass
|
Loading…
Reference in New Issue
Block a user