litert lm centric ???
This commit is contained in:
@@ -7,7 +7,6 @@
|
||||
<file preprocess="xml-stripblanks" alias="ui/save_dialog.ui">views/save_dialog.ui</file>
|
||||
<file preprocess="xml-stripblanks" alias="ui/thread_item.ui">widgets/thread_item.ui</file>
|
||||
<file preprocess="xml-stripblanks" alias="ui/item.ui">widgets/item.ui</file>
|
||||
<file preprocess="xml-stripblanks" alias="ui/provider_item.ui">providers/provider_item.ui</file>
|
||||
<file preprocess="xml-stripblanks" alias="ui/model_item.ui">widgets/model_item.ui</file>
|
||||
<file preprocess="xml-stripblanks" alias="ui/download_row.ui">widgets/download_row.ui</file>
|
||||
<file preprocess="xml-stripblanks" alias="ui/marketplace_item.ui">widgets/marketplace_item.ui</file>
|
||||
|
||||
+116
@@ -0,0 +1,116 @@
|
||||
import litert_lm
|
||||
import os
|
||||
import threading
|
||||
from gi.repository import GLib
|
||||
|
||||
|
||||
class LLM:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.model = None
|
||||
self.engine = None
|
||||
|
||||
def get_data(self):
|
||||
return self.app.data.get("models", {})
|
||||
|
||||
def get_model_path(self):
|
||||
data = self.get_data()
|
||||
|
||||
model_path = data.get("model_path", "")
|
||||
if model_path and os.path.exists(model_path):
|
||||
return model_path
|
||||
|
||||
models_dir = os.path.join(self.app.user_cache_dir, "bavarder", "models")
|
||||
if os.path.exists(models_dir):
|
||||
for f in os.listdir(models_dir):
|
||||
if f.endswith(".litertlm"):
|
||||
return os.path.join(models_dir, f)
|
||||
|
||||
hf_model = data.get("hf_model", "")
|
||||
if hf_model:
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub import list_repo_files
|
||||
files = list_repo_files(hf_model, repo_type="model")
|
||||
litertlm_files = [f for f in files if f.endswith('.litertlm')]
|
||||
if litertlm_files:
|
||||
return hf_hub_download(
|
||||
repo_id=hf_model,
|
||||
filename=litertlm_files[0],
|
||||
cache_dir=self.app.user_cache_dir
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def load_model(self):
|
||||
if self.engine is not None:
|
||||
return
|
||||
|
||||
model_path = self.get_model_path()
|
||||
if not model_path:
|
||||
raise ValueError("No model available. Please download a model or set a model path.")
|
||||
|
||||
self.engine = litert_lm.Engine(model_path, backend=litert_lm.Backend.CPU)
|
||||
|
||||
def ask(self, prompt, chat, callback, error_callback):
|
||||
def thread_run():
|
||||
try:
|
||||
self.load_model()
|
||||
|
||||
messages = []
|
||||
for msg in chat.get("content", []):
|
||||
role = msg.get("role", "user")
|
||||
if role == self.app.user_name:
|
||||
role = "user"
|
||||
elif role == self.app.bot_name:
|
||||
role = "assistant"
|
||||
else:
|
||||
role = "user"
|
||||
content = msg.get("content", "")
|
||||
messages.append({
|
||||
"role": role,
|
||||
"content": [{"type": "text", "text": content}]
|
||||
})
|
||||
|
||||
with self.engine.create_conversation(messages=messages) as conv:
|
||||
response = conv.send_message(prompt)
|
||||
GLib.idle_add(callback, response["content"][0]["text"])
|
||||
except Exception as e:
|
||||
GLib.idle_add(error_callback, str(e))
|
||||
|
||||
t = threading.Thread(target=thread_run)
|
||||
t.start()
|
||||
|
||||
def ask_async(self, prompt, chat, callback, error_callback):
|
||||
def thread_run():
|
||||
try:
|
||||
self.load_model()
|
||||
|
||||
messages = []
|
||||
for msg in chat.get("content", []):
|
||||
role = msg.get("role", "user")
|
||||
if role == self.app.user_name:
|
||||
role = "user"
|
||||
elif role == self.app.bot_name:
|
||||
role = "assistant"
|
||||
else:
|
||||
role = "user"
|
||||
content = msg.get("content", "")
|
||||
messages.append({
|
||||
"role": role,
|
||||
"content": [{"type": "text", "text": content}]
|
||||
})
|
||||
|
||||
with self.engine.create_conversation(messages=messages) as conv:
|
||||
stream = conv.send_message_async(prompt)
|
||||
for chunk in stream:
|
||||
for item in chunk.get("content", []):
|
||||
if item.get("type") == "text":
|
||||
GLib.idle_add(callback, item["text"])
|
||||
except Exception as e:
|
||||
GLib.idle_add(error_callback, str(e))
|
||||
|
||||
t = threading.Thread(target=thread_run)
|
||||
t.start()
|
||||
+5
-41
@@ -31,7 +31,7 @@ from .views.window import BavarderWindow
|
||||
from .views.about_window import AboutWindow
|
||||
from .views.preferences_window import PreferencesWindow
|
||||
from .constants import app_id
|
||||
from .providers import PROVIDERS
|
||||
from .llm import LLM
|
||||
|
||||
import json
|
||||
import os
|
||||
@@ -82,11 +82,6 @@ class BavarderApplication(Adw.Application):
|
||||
|
||||
self.data = {
|
||||
"chats": [],
|
||||
"providers": {
|
||||
"google-flan-t5-xxl": {"enabled": True, "data": {}},
|
||||
"gpt-2": {"enabled": True, "data": {}},
|
||||
|
||||
},
|
||||
"models": {}
|
||||
}
|
||||
|
||||
@@ -94,22 +89,13 @@ class BavarderApplication(Adw.Application):
|
||||
try:
|
||||
with open(self.data_path, "r", encoding="utf-8") as f:
|
||||
self.data = json.load(f)
|
||||
except Exception: # if there is an error, we use a plain config
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.settings = Gio.Settings(schema_id=app_id)
|
||||
|
||||
self.local_mode = self.settings.get_boolean("local-mode")
|
||||
self.current_provider = self.settings.get_string("current-provider")
|
||||
self.model_name = self.settings.get_string("model")
|
||||
|
||||
self.create_stateful_action(
|
||||
"set_provider",
|
||||
GLib.VariantType.new("s"),
|
||||
GLib.Variant("s", self.current_provider),
|
||||
self.on_set_provider_action
|
||||
)
|
||||
|
||||
self.create_stateful_action(
|
||||
"set_model",
|
||||
GLib.VariantType.new("s"),
|
||||
@@ -121,10 +107,7 @@ class BavarderApplication(Adw.Application):
|
||||
self.user_name = self.settings.get_string("user-name")
|
||||
|
||||
self.user_cache_dir = user_cache_dir
|
||||
|
||||
def on_set_provider_action(self, action, *args):
|
||||
self.current_provider = args[0].get_string()
|
||||
Gio.SimpleAction.set_state(self.lookup_action("set_provider"), args[0])
|
||||
self.llm = LLM(self)
|
||||
|
||||
def on_set_model_action(self, action, *args):
|
||||
Gio.SimpleAction.set_state(self.lookup_action("set_model"), args[0])
|
||||
@@ -132,8 +115,6 @@ class BavarderApplication(Adw.Application):
|
||||
def save(self):
|
||||
with open(self.data_path, "w", encoding="utf-8") as f:
|
||||
self.data = json.dump(self.data, f)
|
||||
self.settings.set_boolean("local-mode", self.local_mode)
|
||||
self.settings.set_string("current-provider", self.current_provider)
|
||||
self.settings.set_string("model", self.model_name)
|
||||
self.settings.set_string("bot-name", self.bot_name)
|
||||
self.settings.set_string("user-name", self.user_name)
|
||||
@@ -189,15 +170,7 @@ class BavarderApplication(Adw.Application):
|
||||
|
||||
win.connect("close-request", self.on_close)
|
||||
|
||||
self.providers = {}
|
||||
|
||||
for provider in PROVIDERS:
|
||||
p = provider(self, win)
|
||||
|
||||
self.providers[p.slug] = p
|
||||
|
||||
win.load_model_selector()
|
||||
win.load_provider_selector()
|
||||
win.present()
|
||||
|
||||
|
||||
@@ -250,17 +223,8 @@ class BavarderApplication(Adw.Application):
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def ask(self, prompt, chat):
|
||||
l = list(self.providers.values())
|
||||
|
||||
for p in l:
|
||||
if p.enabled and p.slug == self.current_provider:
|
||||
response = self.providers[self.current_provider].ask(prompt, chat)
|
||||
break
|
||||
else:
|
||||
response = _("Please enable a provider from the Dot Menu")
|
||||
|
||||
return response
|
||||
def ask(self, prompt, chat, callback, error_callback):
|
||||
self.llm.ask(prompt, chat, callback, error_callback)
|
||||
|
||||
def check_network(self):
|
||||
return False
|
||||
|
||||
+2
-3
@@ -15,7 +15,6 @@ blueprints = custom_target('blueprints',
|
||||
'widgets/download_row.blp',
|
||||
'widgets/marketplace_item.blp',
|
||||
'widgets/code_block.blp',
|
||||
'providers/provider_item.blp',
|
||||
),
|
||||
output: '.',
|
||||
command: [find_program('blueprint-compiler'), 'batch-compile', '@OUTPUT@', '@CURRENT_SOURCE_DIR@', '@INPUT@']
|
||||
@@ -49,11 +48,11 @@ configure_file(
|
||||
bavarder_sources = [
|
||||
'__init__.py',
|
||||
'main.py',
|
||||
'llm.py',
|
||||
'threading.py'
|
||||
]
|
||||
|
||||
PY_INSTALLDIR.install_sources(bavarder_sources, subdir: MODULE_DIR)
|
||||
|
||||
subdir('views')
|
||||
subdir('widgets')
|
||||
subdir('providers')
|
||||
subdir('widgets')
|
||||
@@ -1,5 +0,0 @@
|
||||
from .litert_lm import LiteRTLMProvider
|
||||
|
||||
PROVIDERS = [
|
||||
LiteRTLMProvider,
|
||||
]
|
||||
@@ -1,90 +0,0 @@
|
||||
import unicodedata
|
||||
import re
|
||||
from typing import List, Dict
|
||||
from gi.repository import Gtk, Adw, GLib
|
||||
from enum import Enum
|
||||
|
||||
class ProviderType(Enum):
|
||||
IMAGE = _("Image")
|
||||
CHAT = _("Chat")
|
||||
VOICE = _("Voice")
|
||||
TEXT = _("Text")
|
||||
MOVIE = _("Movie")
|
||||
class BaseProvider:
|
||||
name: str
|
||||
description: str = ""
|
||||
provider_type: ProviderType = ProviderType.CHAT
|
||||
languages: List[str] = []
|
||||
developer_name: str = "0xMRTT"
|
||||
developers = ["0xMRTT https://github.com/0xMRTT"]
|
||||
license_type = Gtk.License.GPL_3_0
|
||||
data: Dict[str, str] = {}
|
||||
has_auth: bool = False
|
||||
require_authentification: bool = False
|
||||
base_url = "https://bavarder.codeberg.page/providers/"
|
||||
|
||||
def __init__(self, app, window):
|
||||
self.slug = self.slugify(self.name)
|
||||
self.copyright = f"© 2023 {self.developer_name}"
|
||||
self.url = f"{self.base_url}{self.slug}"
|
||||
|
||||
self.app = app
|
||||
self.window = window
|
||||
|
||||
self.data
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
try:
|
||||
return self.app.data["providers"][self.slug]["data"]
|
||||
except KeyError:
|
||||
self.app.data["providers"][self.slug] = {
|
||||
"enabled": False,
|
||||
"data": {
|
||||
|
||||
}
|
||||
}
|
||||
finally:
|
||||
return self.app.data["providers"][self.slug]["data"]
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.app.data["providers"][self.slug]["enabled"]
|
||||
|
||||
def set_enabled(self, status):
|
||||
self.app.data["providers"][self.slug]["enabled"] = status
|
||||
|
||||
def ask(self, prompt, chat):
|
||||
raise NotImplementedError()
|
||||
|
||||
def load_authentification(self):
|
||||
"""Must set self.has_auth to True when auth is done"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_settings_rows(self) -> list:
|
||||
return []
|
||||
|
||||
# TOOLS
|
||||
def slugify(self, value):
|
||||
value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
|
||||
value = re.sub('[^\w\s-]', '', value).strip().lower()
|
||||
return re.sub('[-\s]+', '-', value)
|
||||
|
||||
def chunk(self, prompt, n=4000):
|
||||
if len(prompt) > n:
|
||||
prompt = [(prompt[i : i + n]) for i in range(0, len(prompt), n)]
|
||||
return prompt
|
||||
|
||||
def open_documentation(self, *args, **kwargs):
|
||||
GLib.spawn_command_line_async(
|
||||
f"xdg-open {self.url}"
|
||||
)
|
||||
|
||||
def how_to_get_a_token(self):
|
||||
about_button = Gtk.Button()
|
||||
about_button.set_icon_name("dialog-information-symbolic")
|
||||
about_button.set_tooltip_text(_("How to get a token"))
|
||||
about_button.add_css_class("flat")
|
||||
about_button.set_valign(Gtk.Align.CENTER)
|
||||
about_button.connect("clicked", self.open_documentation)
|
||||
return about_button
|
||||
@@ -1,143 +0,0 @@
|
||||
import litert_lm
|
||||
from .base import BaseProvider, ProviderType
|
||||
from gi.repository import Gtk, Adw, GLib
|
||||
from huggingface_hub import hf_hub_download
|
||||
import os
|
||||
import threading
|
||||
|
||||
|
||||
class LiteRTLMProvider(BaseProvider):
|
||||
name = "LiteRT-LM"
|
||||
description = _("Run local LLMs using LiteRT-LM")
|
||||
provider_type = ProviderType.CHAT
|
||||
url = "https://ai.google.dev/edge/litert-lm"
|
||||
|
||||
def __init__(self, app, window):
|
||||
super().__init__(app, window)
|
||||
self.model = None
|
||||
self.conversation = None
|
||||
|
||||
def get_settings_rows(self):
|
||||
rows = []
|
||||
|
||||
self.hf_model_row = Adw.EntryRow()
|
||||
self.hf_model_row.connect("apply", self.on_apply)
|
||||
self.hf_model_row.props.title = _("HuggingFace Model (e.g., litert-community/gemma-4-E2B-it-litert-lm)")
|
||||
if "hf_model" in self.data:
|
||||
self.hf_model_row.props.text = str(self.data["hf_model"])
|
||||
else:
|
||||
self.hf_model_row.props.text = ""
|
||||
self.hf_model_row.set_show_apply_button(True)
|
||||
rows.append(self.hf_model_row)
|
||||
|
||||
self.download_button = Gtk.Button()
|
||||
self.download_button.set_label(_("Download Model"))
|
||||
self.download_button.connect("clicked", self.on_download_clicked)
|
||||
rows.append(self.download_button)
|
||||
|
||||
self.model_path_row = Adw.EntryRow()
|
||||
self.model_path_row.connect("apply", self.on_apply)
|
||||
self.model_path_row.props.title = _("Model Path (or leave empty to use HF model)")
|
||||
if "model_path" in self.data:
|
||||
self.model_path_row.props.text = str(self.data["model_path"])
|
||||
else:
|
||||
self.model_path_row.props.text = ""
|
||||
self.model_path_row.set_show_apply_button(True)
|
||||
rows.append(self.model_path_row)
|
||||
|
||||
return rows
|
||||
|
||||
def on_apply(self, widget):
|
||||
hf_model = self.hf_model_row.get_text()
|
||||
model_path = self.model_path_row.get_text()
|
||||
self.data["hf_model"] = hf_model
|
||||
self.data["model_path"] = model_path
|
||||
self.model = None
|
||||
self.conversation = None
|
||||
|
||||
def on_download_clicked(self, widget):
|
||||
def thread_run():
|
||||
try:
|
||||
hf_model = self.hf_model_row.get_text()
|
||||
if not hf_model:
|
||||
GLib.idle_add(self.show_error, _("Please enter a HuggingFace model ID"))
|
||||
return
|
||||
|
||||
toast = Adw.Toast()
|
||||
toast.set_timeout(0)
|
||||
toast.set_title(_("Downloading model from HuggingFace..."))
|
||||
GLib.idle_add(self.window.add_toast, toast)
|
||||
|
||||
model_file = hf_hub_download(
|
||||
repo_id=hf_model,
|
||||
filename="*.litertlm",
|
||||
cache_dir=self.app.user_cache_dir
|
||||
)
|
||||
|
||||
self.data["model_path"] = model_file
|
||||
GLib.idle_add(self.model_path_row.set_text, model_file)
|
||||
GLib.idle_add(toast.dismiss)
|
||||
|
||||
toast = Adw.Toast()
|
||||
toast.set_title(_("Model downloaded successfully!"))
|
||||
GLib.idle_add(self.window.add_toast, toast)
|
||||
|
||||
except Exception as e:
|
||||
GLib.idle_add(self.show_error, str(e))
|
||||
|
||||
t = threading.Thread(target=thread_run)
|
||||
t.start()
|
||||
|
||||
def show_error(self, message):
|
||||
toast = Adw.Toast()
|
||||
toast.set_title(message)
|
||||
self.window.add_toast(toast)
|
||||
|
||||
def get_model_path(self):
|
||||
model_path = self.data.get("model_path", "")
|
||||
if model_path and os.path.exists(model_path):
|
||||
return model_path
|
||||
|
||||
hf_model = self.data.get("hf_model", "")
|
||||
if hf_model:
|
||||
try:
|
||||
return hf_hub_download(
|
||||
repo_id=hf_model,
|
||||
filename="*.litertlm",
|
||||
cache_dir=self.app.user_cache_dir
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def load_model(self):
|
||||
if self.model is not None:
|
||||
return
|
||||
|
||||
model_path = self.get_model_path()
|
||||
if not model_path:
|
||||
raise ValueError("No model available. Please download a model or set a model path.")
|
||||
|
||||
self.model = litert_lm.Engine(model_path, backend=litert_lm.Backend.CPU)
|
||||
|
||||
def ask(self, prompt, chat):
|
||||
self.load_model()
|
||||
|
||||
messages = []
|
||||
for msg in chat["content"]:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
messages.append({
|
||||
"role": role,
|
||||
"content": [{"type": "text", "text": content}]
|
||||
})
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": prompt}]
|
||||
})
|
||||
|
||||
with self.model.create_conversation(messages=messages) as conv:
|
||||
response = conv.send_message(prompt)
|
||||
return response["content"][0]["text"]
|
||||
@@ -1,10 +0,0 @@
|
||||
providers_dir = join_paths(MODULE_DIR, 'providers')
|
||||
|
||||
providers_sources = [
|
||||
'__init__.py',
|
||||
'base.py',
|
||||
'litert_lm.py',
|
||||
'provider_item.py',
|
||||
]
|
||||
|
||||
PY_INSTALLDIR.install_sources(providers_sources, subdir: providers_dir)
|
||||
@@ -1,21 +0,0 @@
|
||||
using Gtk 4.0;
|
||||
using Adw 1;
|
||||
|
||||
template $Provider : Adw.ExpanderRow {
|
||||
[suffix]
|
||||
Switch enable_switch {
|
||||
state-set => $on_switch_state_changed();
|
||||
valign: center;
|
||||
}
|
||||
|
||||
[suffix]
|
||||
Label provider_type {
|
||||
valign: center;
|
||||
styles [ "tag" ]
|
||||
}
|
||||
|
||||
Adw.ActionRow no_preferences_available {
|
||||
title: _("No preferences available");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
import unicodedata
|
||||
import re
|
||||
from typing import List, Dict
|
||||
from gi.repository import Gtk, Adw, GLib
|
||||
|
||||
from bavarder.constants import app_id, rootdir
|
||||
from .base import ProviderType
|
||||
|
||||
@Gtk.Template(resource_path=f"{rootdir}/ui/provider_item.ui")
|
||||
class Provider(Adw.ExpanderRow):
|
||||
__gtype_name__ = "Provider"
|
||||
|
||||
enable_switch = Gtk.Template.Child()
|
||||
no_preferences_available = Gtk.Template.Child()
|
||||
provider_type = Gtk.Template.Child()
|
||||
|
||||
def __init__(self, app, window, provider, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.app = app
|
||||
self.window = window
|
||||
self.provider = provider
|
||||
|
||||
self.setup()
|
||||
|
||||
def setup(self):
|
||||
self.set_title(self.provider.name)
|
||||
self.set_subtitle(self.provider.description)
|
||||
self.provider_type.set_label(self.provider.provider_type.value)
|
||||
match self.provider.provider_type:
|
||||
case ProviderType.IMAGE:
|
||||
self.provider_type.add_css_class("badge-titanium")
|
||||
case ProviderType.CHAT:
|
||||
self.provider_type.add_css_class("badge-gold")
|
||||
case ProviderType.VOICE:
|
||||
self.provider_type.add_css_class("badge-iron")
|
||||
case ProviderType.TEXT:
|
||||
self.provider_type.add_css_class("badge-tin")
|
||||
case ProviderType.MOVIE:
|
||||
self.provider_type.add_css_class("badge-silver")
|
||||
|
||||
self.enable_switch.set_active( self.app.data["providers"][self.provider.slug]["enabled"])
|
||||
|
||||
if self.provider.get_settings_rows():
|
||||
self.no_preferences_available.set_visible(False)
|
||||
|
||||
for row in self.provider.get_settings_rows():
|
||||
self.add_row(row)
|
||||
|
||||
# CALLBACKS
|
||||
@Gtk.Template.Callback()
|
||||
def on_switch_state_changed(self, widget, _):
|
||||
self.provider.set_enabled(widget.get_active())
|
||||
self.app.win.load_provider_selector()
|
||||
|
||||
# TOOLS
|
||||
def slugify(self, value):
|
||||
value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
|
||||
value = re.sub('[^\w\s-]', '', value).strip().lower()
|
||||
return re.sub('[-\s]+', '-', value)
|
||||
|
||||
def chunk(self, prompt, n=4000):
|
||||
if len(prompt) > n:
|
||||
prompt = [(prompt[i : i + n]) for i in range(0, len(prompt), n)]
|
||||
return prompt
|
||||
|
||||
@@ -11,10 +11,6 @@ template $Preferences : Adw.PreferencesWindow {
|
||||
title: _("Models");
|
||||
icon-name: "brain-augemnted-symbolic";
|
||||
|
||||
Adw.PreferencesGroup provider_group {
|
||||
title: _("Providers");
|
||||
}
|
||||
|
||||
Adw.PreferencesGroup model_group {
|
||||
title: _("Models");
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ from bavarder.widgets.model_item import Model
|
||||
class PreferencesWindow(Adw.PreferencesWindow):
|
||||
__gtype_name__ = "Preferences"
|
||||
|
||||
provider_group = Gtk.Template.Child()
|
||||
general_page = Gtk.Template.Child()
|
||||
model_group = Gtk.Template.Child()
|
||||
marketplace_group = Gtk.Template.Child()
|
||||
@@ -63,10 +62,8 @@ class PreferencesWindow(Adw.PreferencesWindow):
|
||||
try:
|
||||
from huggingface_hub import list_models
|
||||
models = list(list_models(
|
||||
filter={"library_name": "litert-lm"},
|
||||
author="litert-community",
|
||||
sort="downloads",
|
||||
direction=-1,
|
||||
limit=50
|
||||
))
|
||||
model_list = []
|
||||
for m in models:
|
||||
@@ -80,7 +77,11 @@ class PreferencesWindow(Adw.PreferencesWindow):
|
||||
GLib.idle_add(show_error, str(e))
|
||||
|
||||
def update_ui(model_list):
|
||||
self.marketplace_group.remove_all()
|
||||
child = self.marketplace_group.get_first_child()
|
||||
while child is not None:
|
||||
next_child = child.get_next_sibling()
|
||||
self.marketplace_group.remove(child)
|
||||
child = next_child
|
||||
for m in model_list:
|
||||
item = MarketplaceItem(self.app, self.win, m)
|
||||
self.marketplace_group.add(item)
|
||||
|
||||
@@ -97,12 +97,6 @@ template $BavarderWindow : Adw.ApplicationWindow {
|
||||
child: Adw.ToolbarView {
|
||||
[top]
|
||||
Adw.HeaderBar {
|
||||
[start]
|
||||
Gtk.ToggleButton local_mode_toggle {
|
||||
icon-name: 'cloud-disabled-symbolic';
|
||||
toggled => $on_local_mode_toggled();
|
||||
}
|
||||
|
||||
[title]
|
||||
Adw.WindowTitle title {
|
||||
title: _("Chat");
|
||||
@@ -118,13 +112,6 @@ template $BavarderWindow : Adw.ApplicationWindow {
|
||||
[end]
|
||||
MenuButton model_selector_button {
|
||||
icon-name: 'view-more-symbolic';
|
||||
visible: false;
|
||||
}
|
||||
|
||||
[end]
|
||||
MenuButton provider_selector_button {
|
||||
icon-name: 'view-more-symbolic';
|
||||
visible: false;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
+65
-138
@@ -18,9 +18,10 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
from datetime import datetime
|
||||
import locale
|
||||
import io
|
||||
import locale
|
||||
import io
|
||||
import base64
|
||||
import os
|
||||
|
||||
from gi.repository import Gtk, Gio, Adw, GLib
|
||||
from babel.dates import format_date, format_datetime, format_time
|
||||
@@ -50,8 +51,6 @@ class BavarderWindow(Adw.ApplicationWindow):
|
||||
status_no_thread_main = Gtk.Template.Child()
|
||||
status_no_internet = Gtk.Template.Child()
|
||||
scrolled_window = Gtk.Template.Child()
|
||||
local_mode_toggle = Gtk.Template.Child()
|
||||
provider_selector_button = Gtk.Template.Child()
|
||||
model_selector_button = Gtk.Template.Child()
|
||||
banner = Gtk.Template.Child()
|
||||
toast_overlay = Gtk.Template.Child()
|
||||
@@ -82,10 +81,6 @@ class BavarderWindow(Adw.ApplicationWindow):
|
||||
self.scrolled_window.set_child(self.message_entry)
|
||||
self.load_threads()
|
||||
|
||||
self.local_mode_toggle.set_active(self.app.local_mode)
|
||||
|
||||
self.on_local_mode_toggled(self.local_mode_toggle)
|
||||
|
||||
self.create_action("cancel", self.cancel, ["<primary>Escape"])
|
||||
self.create_action("clear_all", self.on_clear_all)
|
||||
self.create_action("export", self.on_export, ["<primary>e"])
|
||||
@@ -115,7 +110,7 @@ class BavarderWindow(Adw.ApplicationWindow):
|
||||
except AttributeError: # create a new chat
|
||||
#self.on_new_chat_action()
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
@@ -186,7 +181,7 @@ class BavarderWindow(Adw.ApplicationWindow):
|
||||
i += 1
|
||||
item = Item(self, self.chat, item)
|
||||
self.main_list.append(item)
|
||||
|
||||
|
||||
for i in range(i):
|
||||
row = self.main_list.get_row_at_index(i)
|
||||
row.set_selectable(False)
|
||||
@@ -257,93 +252,54 @@ class BavarderWindow(Adw.ApplicationWindow):
|
||||
toast.set_title(_("Nothing to export!"))
|
||||
self.toast_overlay.add_toast(toast)
|
||||
|
||||
# PROVIDER - ONLINE
|
||||
def load_provider_selector(self):
|
||||
provider_menu = Gio.Menu()
|
||||
|
||||
section = Gio.Menu()
|
||||
for provider in self.app.providers.values():
|
||||
if provider.enabled:
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(provider.name)
|
||||
item_provider.set_action_and_target_value(
|
||||
"app.set_provider",
|
||||
GLib.Variant("s", provider.slug))
|
||||
section.append_item(item_provider)
|
||||
else:
|
||||
if self.app.providers:
|
||||
provider_menu.append_section(_("Providers"), section)
|
||||
section = Gio.Menu()
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(_("Preferences"))
|
||||
item_provider.set_action_and_target_value("app.preferences", None)
|
||||
section.append_item(item_provider)
|
||||
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(_("Clear all"))
|
||||
item_provider.set_action_and_target_value("win.clear_all", None)
|
||||
section.append_item(item_provider)
|
||||
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(_("Export"))
|
||||
item_provider.set_action_and_target_value("win.export", None)
|
||||
section.append_item(item_provider)
|
||||
|
||||
provider_menu.append_section(None, section)
|
||||
|
||||
self.provider_selector_button.set_menu_model(provider_menu)
|
||||
|
||||
# MODEL - OFFLINE
|
||||
def load_model_selector(self):
|
||||
provider_menu = Gio.Menu()
|
||||
|
||||
if not self.app.models:
|
||||
self.app.list_models()
|
||||
section = Gio.Menu()
|
||||
|
||||
models = set()
|
||||
|
||||
model_path = self.app.data.get("models", {}).get("model_path", "")
|
||||
if model_path and os.path.exists(model_path):
|
||||
models.add(os.path.basename(model_path))
|
||||
|
||||
models_dir = os.path.join(self.app.user_cache_dir, "bavarder", "models")
|
||||
if os.path.exists(models_dir):
|
||||
for f in os.listdir(models_dir):
|
||||
if f.endswith(".litertlm"):
|
||||
models.add(f)
|
||||
|
||||
if models:
|
||||
for model in models:
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(model)
|
||||
item_provider.set_action_and_target_value(
|
||||
"app.set_model",
|
||||
GLib.Variant("s", model))
|
||||
section.append_item(item_provider)
|
||||
provider_menu.append_section(_("Models"), section)
|
||||
|
||||
section = Gio.Menu()
|
||||
for provider in self.app.models:
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(provider)
|
||||
item_provider.set_action_and_target_value(
|
||||
"app.set_model",
|
||||
GLib.Variant("s", provider))
|
||||
section.append_item(item_provider)
|
||||
else:
|
||||
if self.app.models:
|
||||
provider_menu.append_section(_("Models"), section)
|
||||
section = Gio.Menu()
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(_("Preferences"))
|
||||
item_provider.set_action_and_target_value("app.preferences", None)
|
||||
section.append_item(item_provider)
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(_("Preferences"))
|
||||
item_provider.set_action_and_target_value("app.preferences", None)
|
||||
section.append_item(item_provider)
|
||||
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(_("Clear all"))
|
||||
item_provider.set_action_and_target_value("win.clear_all", None)
|
||||
section.append_item(item_provider)
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(_("Clear all"))
|
||||
item_provider.set_action_and_target_value("win.clear_all", None)
|
||||
section.append_item(item_provider)
|
||||
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(_("Export"))
|
||||
item_provider.set_action_and_target_value("win.export", None)
|
||||
section.append_item(item_provider)
|
||||
item_provider = Gio.MenuItem()
|
||||
item_provider.set_label(_("Export"))
|
||||
item_provider.set_action_and_target_value("win.export", None)
|
||||
section.append_item(item_provider)
|
||||
|
||||
provider_menu.append_section(None, section)
|
||||
provider_menu.append_section(None, section)
|
||||
|
||||
self.model_selector_button.set_menu_model(provider_menu)
|
||||
|
||||
@Gtk.Template.Callback()
|
||||
def on_local_mode_toggled(self, widget):
|
||||
self.app.local_mode = widget.get_active()
|
||||
|
||||
if self.app.local_mode:
|
||||
self.local_mode_toggle.set_icon_name("cloud-disabled-symbolic")
|
||||
self.model_selector_button.set_visible(True)
|
||||
self.provider_selector_button.set_visible(False)
|
||||
else:
|
||||
self.local_mode_toggle.set_icon_name("cloud-filled-symbolic")
|
||||
self.provider_selector_button.set_visible(True)
|
||||
self.model_selector_button.set_visible(False)
|
||||
|
||||
def check_network(self):
|
||||
if self.app.check_network(): # Internet
|
||||
if not self.content:
|
||||
@@ -367,50 +323,33 @@ class BavarderWindow(Adw.ApplicationWindow):
|
||||
if not self.chat:
|
||||
self.on_new_chat_action()
|
||||
|
||||
# now get the latest row
|
||||
# now get the latest row
|
||||
row = self.threads_list.get_row_at_index(len(self.app.data["chats"]) - 1)
|
||||
|
||||
|
||||
|
||||
self.threads_list.select_row(row)
|
||||
self.threads_row_activated_cb()
|
||||
|
||||
|
||||
|
||||
self.add_user_item(prompt)
|
||||
|
||||
|
||||
def thread_run():
|
||||
self.toast = Adw.Toast()
|
||||
self.toast.set_title(_("Generating response"))
|
||||
self.toast.set_button_label(_("Cancel"))
|
||||
self.toast.set_action_name("win.cancel")
|
||||
self.toast.set_timeout(0)
|
||||
self.toast_overlay.add_toast(self.toast)
|
||||
response = self.app.ask(prompt, self.chat)
|
||||
GLib.idle_add(cleanup, response, self.toast)
|
||||
|
||||
def cleanup(response, toast):
|
||||
try:
|
||||
self.t.join()
|
||||
self.toast.dismiss()
|
||||
|
||||
if not response:
|
||||
self.add_assistant_item(_("Sorry, I don't know what to say."))
|
||||
else:
|
||||
if isinstance(response, str):
|
||||
self.add_assistant_item(response)
|
||||
else:
|
||||
buffered = io.BytesIO()
|
||||
response.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue())
|
||||
|
||||
self.add_assistant_item(img_str.decode("utf-8"))
|
||||
|
||||
except AttributeError:
|
||||
self.toast.dismiss()
|
||||
def on_response(response):
|
||||
if not response:
|
||||
self.add_assistant_item(_("Sorry, I don't know what to say."))
|
||||
else:
|
||||
self.add_assistant_item(response)
|
||||
self.toast.dismiss()
|
||||
|
||||
self.t = KillableThread(target=thread_run)
|
||||
self.t.start()
|
||||
def on_error(error):
|
||||
self.toast.dismiss()
|
||||
self.add_assistant_item(_("Error: ") + error)
|
||||
|
||||
self.toast = Adw.Toast()
|
||||
self.toast.set_title(_("Generating response"))
|
||||
self.toast.set_timeout(0)
|
||||
self.toast_overlay.add_toast(self.toast)
|
||||
|
||||
self.app.ask(prompt, self.chat, on_response, on_error)
|
||||
|
||||
# @Gtk.Template.Callback()
|
||||
# def on_emoji(self, *args):
|
||||
@@ -437,7 +376,7 @@ class BavarderWindow(Adw.ApplicationWindow):
|
||||
|
||||
if shortcuts:
|
||||
self.app.set_accels_for_action(f"win.{name}", shortcuts)
|
||||
|
||||
|
||||
def get_time(self):
|
||||
return format_time(datetime.now())
|
||||
|
||||
@@ -458,26 +397,14 @@ class BavarderWindow(Adw.ApplicationWindow):
|
||||
|
||||
def add_assistant_item(self, content):
|
||||
c = {
|
||||
"role": self.app.bot_name,
|
||||
"content": content,
|
||||
"time": self.get_time(),
|
||||
}
|
||||
"role": self.app.bot_name,
|
||||
"content": content,
|
||||
"time": self.get_time(),
|
||||
"model": "litert-lm",
|
||||
}
|
||||
|
||||
l = list(self.app.providers.values())
|
||||
|
||||
for p in l:
|
||||
if p.enabled and p.slug == self.app.current_provider:
|
||||
c["model"] = self.app.current_provider
|
||||
break
|
||||
else:
|
||||
c["model"] = "bavarder"
|
||||
|
||||
|
||||
self.content.append(c)
|
||||
|
||||
self.threads_row_activated_cb()
|
||||
|
||||
self.scroll_down()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -22,24 +22,40 @@ class MarketplaceItem(Adw.ActionRow):
|
||||
|
||||
@Gtk.Template.Callback()
|
||||
def on_download_button_clicked(self, widget, *args):
|
||||
self.download_cancelled = False
|
||||
|
||||
def thread_run():
|
||||
self.app.action_running_in_background = True
|
||||
|
||||
toast = Adw.Toast()
|
||||
toast.set_timeout(0)
|
||||
toast.set_title(_("Downloading model %s" % self.model_info.get("name")))
|
||||
self.window.add_toast(toast)
|
||||
self.window.toast_overlay.add_toast(toast)
|
||||
|
||||
model_id = self.model_info.get("id")
|
||||
from huggingface_hub import hf_hub_download
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename="*.litertlm",
|
||||
cache_dir=self.app.user_cache_dir
|
||||
)
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
|
||||
files = list_repo_files(model_id, repo_type="model")
|
||||
litertlm_files = [f for f in files if f.endswith('.litertlm')]
|
||||
|
||||
if not litertlm_files:
|
||||
GLib.idle_add(show_error, _("No .litertlm file found in this model"))
|
||||
return
|
||||
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=litertlm_files[0],
|
||||
cache_dir=self.app.user_cache_dir,
|
||||
)
|
||||
except Exception as e:
|
||||
GLib.idle_add(show_error, str(e))
|
||||
return
|
||||
|
||||
self.app.data["providers"]["litert-lm"]["data"]["model_path"] = model_file
|
||||
self.app.data["providers"]["litert-lm"]["data"]["hf_model"] = model_id
|
||||
if "models" not in self.app.data:
|
||||
self.app.data["models"] = {}
|
||||
self.app.data["models"]["model_path"] = model_file
|
||||
self.app.data["models"]["hf_model"] = model_id
|
||||
GLib.idle_add(cleanup, toast, model_file)
|
||||
|
||||
def cleanup(toast, model_file):
|
||||
@@ -51,7 +67,19 @@ class MarketplaceItem(Adw.ActionRow):
|
||||
|
||||
toast = Adw.Toast()
|
||||
toast.set_title(_("Model %s downloaded!" % self.model_info.get("name")))
|
||||
self.window.add_toast(toast)
|
||||
self.window.toast_overlay.add_toast(toast)
|
||||
|
||||
self.set_subtitle(self.model_info.get("id"))
|
||||
|
||||
def show_error(message):
|
||||
try:
|
||||
t.join()
|
||||
except Exception:
|
||||
pass
|
||||
self.app.action_running_in_background = False
|
||||
toast = Adw.Toast()
|
||||
toast.set_title(message)
|
||||
self.window.toast_overlay.add_toast(toast)
|
||||
|
||||
t = KillableThread(target=thread_run)
|
||||
t.start()
|
||||
|
||||
Reference in New Issue
Block a user