Fixing problem using openapi

This commit is contained in:
«Juan
2024-01-21 18:16:04 +01:00
committed by 0xmrtt
parent c54e7acc08
commit 62367ce9f4

View File

@@ -1,7 +1,9 @@
from .base import BaseProvider from .base import BaseProvider
import openai import openai
from openai import OpenAI
import socket import socket
import os
import httpx
from gi.repository import Gtk, Adw, GLib from gi.repository import Gtk, Adw, GLib
@@ -9,15 +11,18 @@ from gi.repository import Gtk, Adw, GLib
class BaseOpenAIProvider(BaseProvider): class BaseOpenAIProvider(BaseProvider):
model = None model = None
api_key_title = "API Key" api_key_title = "API Key"
chat = openai.ChatCompletion client = OpenAI(
# This is the default and can be omitted
api_key=os.environ.get("OPENAI_API_KEY"),
)
def __init__(self, app, window): def __init__(self, app, window):
super().__init__(app, window) super().__init__(app, window)
if self.data.get("api_key"): if self.data.get("api_key"):
openai.api_key = self.data["api_key"] self.client.api_key = self.data["api_key"]
if self.data.get("api_base"): if self.data.get("api_base"):
openai.api_base = self.data["api_base"] self.client.base_url = httpx.URL(self.data["api_base"])
def ask(self, prompt, chat): def ask(self, prompt, chat):
_chat = [] _chat = []
@@ -29,26 +34,21 @@ class BaseOpenAIProvider(BaseProvider):
_chat.append({"role": role, "content": c["content"]}) _chat.append({"role": role, "content": c["content"]})
chat = _chat chat = _chat
if self.data.get("api_key"):
openai.api_key = self.data["api_key"]
if self.data.get("api_base"):
openai.api_base = self.data["api_base"]
if self.model: if self.model:
prompt = self.chunk(prompt) prompt = self.chunk(prompt)
try: try:
print(chat) print(chat)
response = self.chat.create( response = self.client.chat.completions.create(
model=self.model, model=self.model,
messages=chat, messages=chat,
).choices[0].message.content ).choices[0].message.content
except openai.error.AuthenticationError: except openai.AuthenticationError:
return _("Your API key is invalid, please check your preferences.") return _("Your API key is invalid, please check your preferences.")
except openai.error.InvalidRequestError: except openai.BadRequestError:
return _("You don't have access to this model, please check your plan and billing details.") return _("You don't have access to this model, please check your plan and billing details.")
except openai.error.RateLimitError: except openai.RateLimitError:
return _("You exceeded your current quota, please check your plan and billing details.") return _("You exceeded your current quota, please check your plan and billing details.")
except openai.error.APIError: except openai.APIConnectionError:
return _("I'm having trouble connecting to the API, please check your internet connection.") return _("I'm having trouble connecting to the API, please check your internet connection.")
except socket.gaierror: except socket.gaierror:
return _("I'm having trouble connecting to the API, please check your internet connection.") return _("I'm having trouble connecting to the API, please check your internet connection.")
@@ -64,7 +64,7 @@ class BaseOpenAIProvider(BaseProvider):
self.api_row = Adw.PasswordEntryRow() self.api_row = Adw.PasswordEntryRow()
self.api_row.connect("apply", self.on_apply) self.api_row.connect("apply", self.on_apply)
self.api_row.props.text = openai.api_key or "" self.api_row.props.text = self.client.api_key or ""
self.api_row.props.title = self.api_key_title self.api_row.props.title = self.api_key_title
self.api_row.set_show_apply_button(True) self.api_row.set_show_apply_button(True)
self.api_row.add_suffix(self.how_to_get_a_token()) self.api_row.add_suffix(self.how_to_get_a_token())
@@ -72,7 +72,7 @@ class BaseOpenAIProvider(BaseProvider):
self.api_url_row = Adw.EntryRow() self.api_url_row = Adw.EntryRow()
self.api_url_row.connect("apply", self.on_apply) self.api_url_row.connect("apply", self.on_apply)
self.api_url_row.props.text = openai.api_base or "" self.api_url_row.props.text=str(self.client.base_url) or ""
self.api_url_row.props.title = "API Url" self.api_url_row.props.title = "API Url"
self.api_url_row.set_show_apply_button(True) self.api_url_row.set_show_apply_button(True)
self.api_url_row.add_suffix(self.how_to_get_base_url()) self.api_url_row.add_suffix(self.how_to_get_base_url())
@@ -82,11 +82,11 @@ class BaseOpenAIProvider(BaseProvider):
def on_apply(self, widget): def on_apply(self, widget):
api_key = self.api_row.get_text() api_key = self.api_row.get_text()
openai.api_key = api_key self.client.api_key = api_key
openai.api_base = self.api_url_row.get_text() self.client.base_url = self.api_url_row.get_text()
self.data["api_key"] = openai.api_key self.data["api_key"] = self.client.api_key
self.data["api_base"] = openai.api_base self.data["api_base"] = str(self.client.base_url)
def how_to_get_base_url(self): def how_to_get_base_url(self):