# Bavarder/src/providers/hfbasechat.py (70 lines, 2.5 KiB, Python)

from .base import BaseProvider, ProviderType
import requests
from gi.repository import Gtk, Adw, GLib
class BaseHFChatProvider(BaseProvider):
    """Base provider that talks to the Hugging Face Inference API.

    Subclasses set ``provider`` to a model id; it is appended to the Inference
    API base URL when a request is made. Depending on ``self.provider_type``
    the request is sent either as a conversational payload or as a plain
    text-generation prompt built by :meth:`make_prompt`.
    """

    # Model id appended to the Inference API URL; expected to be set by subclasses.
    provider = None
    chat_mode = True
    # Seconds to wait for the Inference API. Without a timeout
    # ``requests.post`` can block forever on a stalled connection.
    request_timeout = 120

    def ask(self, prompt, chat, **kwargs):
        """Send *prompt* (plus conversation history) and return the reply text.

        Args:
            prompt: The user's new message.
            chat: Mapping whose ``"content"`` entry is the list of previous
                messages, each a dict with ``"role"`` and ``"content"`` keys.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The generated text on success, otherwise a translated,
            user-facing error message (never ``None``).

        Raises:
            requests.RequestException: if the HTTP request itself fails
                (including timeouts) or the body is not valid JSON.
        """
        history = chat["content"]

        if self.provider_type == ProviderType.CHAT:
            payload = {
                "inputs": {
                    "past_user_inputs": [
                        i["content"] for i in history if i["role"] == self.app.user_name
                    ],
                    "generated_responses": [
                        i["content"] for i in history if i["role"] == self.app.bot_name
                    ],
                    "text": prompt,
                },
            }
        else:
            payload = {"inputs": self.make_prompt(prompt, history)}

        output = self._query(payload)

        text = self._extract_generated_text(output)
        if text is not None:
            return text

        if (
            isinstance(output, dict)
            and output.get("error")
            == "Rate limit reached. Please log in or use your apiToken"
        ):
            return _("You've reached the rate limit! Please add a token to the preferences. You can get the token by following this [guide](https://bavarder.codeberg.page/help/huggingface/)")

        # Any other error or unexpected shape: show it to the user instead of
        # silently returning None. The msgid is unchanged so existing
        # translations still match; the placeholder is filled in here.
        return _("Sorry, I don't know what to say! (Error: {output})").format(output=output)

    def _query(self, payload):
        """POST *payload* as JSON to this provider's model endpoint; return the decoded reply."""
        api_url = f"https://api-inference.huggingface.co/models/{self.provider}"
        headers = {}
        # Only authenticate when the user configured a key; anonymous calls are allowed.
        if self.data.get("api_key"):
            headers["Authorization"] = f"Bearer {self.data['api_key']}"
        response = requests.post(
            api_url, json=payload, headers=headers, timeout=self.request_timeout
        )
        return response.json()

    @staticmethod
    def _extract_generated_text(output):
        """Return ``generated_text`` from an API reply, or None if absent.

        Handles a reply that is either a dict or a list whose first element
        is a dict; an empty list yields None instead of raising.
        """
        if isinstance(output, list):
            output = output[0] if output else None
        if isinstance(output, dict):
            return output.get("generated_text")
        return None

    def get_settings_rows(self):
        """Build the preference rows for this provider.

        Returns:
            A list holding one ``Adw.PasswordEntryRow`` for the API key. The row is
            pre-filled with the stored key, shows an apply button, and carries
            the suffix widget returned by ``self.how_to_get_a_token()``.
        """
        self.rows = []

        self.api_row = Adw.PasswordEntryRow()
        self.api_row.connect("apply", self.on_apply)
        self.api_row.props.text = self.data.get('api_key') or ""
        self.api_row.props.title = _("API Key")
        self.api_row.set_show_apply_button(True)
        self.api_row.add_suffix(self.how_to_get_a_token())
        self.rows.append(self.api_row)

        return self.rows

    def on_apply(self, widget):
        """Store the API key typed into the settings row in ``self.data``."""
        api_key = self.api_row.get_text()
        self.data["api_key"] = api_key

    def make_prompt(self, prompt, chat):
        """Build the text-generation prompt from *prompt* and history *chat*.

        The base implementation ignores *chat* and returns *prompt* unchanged;
        presumably subclasses override it to embed the history.
        """
        return prompt