Fixing problem using openapi
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
from .base import BaseProvider
|
from .base import BaseProvider
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
from openai import OpenAI
|
||||||
import socket
|
import socket
|
||||||
|
import os
|
||||||
|
import httpx
|
||||||
|
|
||||||
from gi.repository import Gtk, Adw, GLib
|
from gi.repository import Gtk, Adw, GLib
|
||||||
|
|
||||||
@@ -9,15 +11,18 @@ from gi.repository import Gtk, Adw, GLib
|
|||||||
class BaseOpenAIProvider(BaseProvider):
|
class BaseOpenAIProvider(BaseProvider):
|
||||||
model = None
|
model = None
|
||||||
api_key_title = "API Key"
|
api_key_title = "API Key"
|
||||||
chat = openai.ChatCompletion
|
client = OpenAI(
|
||||||
|
# This is the default and can be omitted
|
||||||
|
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, app, window):
|
def __init__(self, app, window):
|
||||||
super().__init__(app, window)
|
super().__init__(app, window)
|
||||||
|
|
||||||
if self.data.get("api_key"):
|
if self.data.get("api_key"):
|
||||||
openai.api_key = self.data["api_key"]
|
self.client.api_key = self.data["api_key"]
|
||||||
if self.data.get("api_base"):
|
if self.data.get("api_base"):
|
||||||
openai.api_base = self.data["api_base"]
|
self.client.base_url = httpx.URL(self.data["api_base"])
|
||||||
|
|
||||||
def ask(self, prompt, chat):
|
def ask(self, prompt, chat):
|
||||||
_chat = []
|
_chat = []
|
||||||
@@ -29,26 +34,21 @@ class BaseOpenAIProvider(BaseProvider):
|
|||||||
_chat.append({"role": role, "content": c["content"]})
|
_chat.append({"role": role, "content": c["content"]})
|
||||||
chat = _chat
|
chat = _chat
|
||||||
|
|
||||||
if self.data.get("api_key"):
|
|
||||||
openai.api_key = self.data["api_key"]
|
|
||||||
if self.data.get("api_base"):
|
|
||||||
openai.api_base = self.data["api_base"]
|
|
||||||
|
|
||||||
if self.model:
|
if self.model:
|
||||||
prompt = self.chunk(prompt)
|
prompt = self.chunk(prompt)
|
||||||
try:
|
try:
|
||||||
print(chat)
|
print(chat)
|
||||||
response = self.chat.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=chat,
|
messages=chat,
|
||||||
).choices[0].message.content
|
).choices[0].message.content
|
||||||
except openai.error.AuthenticationError:
|
except openai.AuthenticationError:
|
||||||
return _("Your API key is invalid, please check your preferences.")
|
return _("Your API key is invalid, please check your preferences.")
|
||||||
except openai.error.InvalidRequestError:
|
except openai.BadRequestError:
|
||||||
return _("You don't have access to this model, please check your plan and billing details.")
|
return _("You don't have access to this model, please check your plan and billing details.")
|
||||||
except openai.error.RateLimitError:
|
except openai.RateLimitError:
|
||||||
return _("You exceeded your current quota, please check your plan and billing details.")
|
return _("You exceeded your current quota, please check your plan and billing details.")
|
||||||
except openai.error.APIError:
|
except openai.APIConnectionError:
|
||||||
return _("I'm having trouble connecting to the API, please check your internet connection.")
|
return _("I'm having trouble connecting to the API, please check your internet connection.")
|
||||||
except socket.gaierror:
|
except socket.gaierror:
|
||||||
return _("I'm having trouble connecting to the API, please check your internet connection.")
|
return _("I'm having trouble connecting to the API, please check your internet connection.")
|
||||||
@@ -64,7 +64,7 @@ class BaseOpenAIProvider(BaseProvider):
|
|||||||
|
|
||||||
self.api_row = Adw.PasswordEntryRow()
|
self.api_row = Adw.PasswordEntryRow()
|
||||||
self.api_row.connect("apply", self.on_apply)
|
self.api_row.connect("apply", self.on_apply)
|
||||||
self.api_row.props.text = openai.api_key or ""
|
self.api_row.props.text = self.client.api_key or ""
|
||||||
self.api_row.props.title = self.api_key_title
|
self.api_row.props.title = self.api_key_title
|
||||||
self.api_row.set_show_apply_button(True)
|
self.api_row.set_show_apply_button(True)
|
||||||
self.api_row.add_suffix(self.how_to_get_a_token())
|
self.api_row.add_suffix(self.how_to_get_a_token())
|
||||||
@@ -72,7 +72,7 @@ class BaseOpenAIProvider(BaseProvider):
|
|||||||
|
|
||||||
self.api_url_row = Adw.EntryRow()
|
self.api_url_row = Adw.EntryRow()
|
||||||
self.api_url_row.connect("apply", self.on_apply)
|
self.api_url_row.connect("apply", self.on_apply)
|
||||||
self.api_url_row.props.text = openai.api_base or ""
|
self.api_url_row.props.text=str(self.client.base_url) or ""
|
||||||
self.api_url_row.props.title = "API Url"
|
self.api_url_row.props.title = "API Url"
|
||||||
self.api_url_row.set_show_apply_button(True)
|
self.api_url_row.set_show_apply_button(True)
|
||||||
self.api_url_row.add_suffix(self.how_to_get_base_url())
|
self.api_url_row.add_suffix(self.how_to_get_base_url())
|
||||||
@@ -82,11 +82,11 @@ class BaseOpenAIProvider(BaseProvider):
|
|||||||
|
|
||||||
def on_apply(self, widget):
|
def on_apply(self, widget):
|
||||||
api_key = self.api_row.get_text()
|
api_key = self.api_row.get_text()
|
||||||
openai.api_key = api_key
|
self.client.api_key = api_key
|
||||||
openai.api_base = self.api_url_row.get_text()
|
self.client.base_url = self.api_url_row.get_text()
|
||||||
|
|
||||||
self.data["api_key"] = openai.api_key
|
self.data["api_key"] = self.client.api_key
|
||||||
self.data["api_base"] = openai.api_base
|
self.data["api_base"] = str(self.client.base_url)
|
||||||
|
|
||||||
|
|
||||||
def how_to_get_base_url(self):
|
def how_to_get_base_url(self):
|
||||||
|
Reference in New Issue
Block a user