This commit is contained in:
parent
54ef8641fc
commit
9362ec4563
@ -3,6 +3,8 @@ import os
|
||||
import sys
|
||||
from baichat_py import Completion
|
||||
|
||||
import traceback
|
||||
|
||||
PREFIX = "!"
|
||||
|
||||
USERNAME = os.environ.get("MATRIX_USERNAME", "ai")
|
||||
@ -542,14 +544,15 @@ def validate_cfg(cfg: float) -> str:
|
||||
class AsyncImagine:
|
||||
"""Async class for handling API requests to the Imagine service."""
|
||||
|
||||
HEADERS = {"accept": "*/*", "user-agent": "okhttp/4.10.0"}
|
||||
HEADERS = {
|
||||
"accept": "*/*",
|
||||
"user-agent": "okhttp/4.10.0"
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.asset = "https://1966211409.rsc.cdn77.org"
|
||||
self.api = "https://inferenceengine.vyro.ai"
|
||||
self.session = aiohttp.ClientSession(
|
||||
raise_for_status=True, headers=self.HEADERS
|
||||
)
|
||||
self.session = aiohttp.ClientSession(raise_for_status=True, headers=self.HEADERS)
|
||||
self.version = "1"
|
||||
|
||||
async def close(self) -> None:
|
||||
@ -568,149 +571,91 @@ class AsyncImagine:
|
||||
|
||||
async def assets(self, style: Style = Style.IMAGINE_V1) -> bytes:
|
||||
"""Gets the assets."""
|
||||
async with self.session.get(url=self.get_style_url(style=style)) as resp:
|
||||
return await resp.read()
|
||||
|
||||
async def variate(
|
||||
self, image: bytes, prompt: str, style: Style = Style.IMAGINE_V1
|
||||
) -> bytes:
|
||||
async with self.session.post(
|
||||
url=f"{self.api}/variate",
|
||||
data={
|
||||
"model_version": self.version,
|
||||
"prompt": prompt + (style.value[3] or ""),
|
||||
"strength": "0",
|
||||
"style_id": str(style.value[0]),
|
||||
"image": self.bytes_to_io(image, "image.png"),
|
||||
},
|
||||
async with self.session.get(
|
||||
url=self.get_style_url(style=style)
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
|
||||
async def sdprem(
|
||||
self,
|
||||
prompt: str,
|
||||
negative: str = None,
|
||||
priority: str = None,
|
||||
steps: str = None,
|
||||
high_res_results: str = None,
|
||||
style: Style = Style.IMAGINE_V1,
|
||||
seed: str = None,
|
||||
ratio: Ratio = Ratio.RATIO_1X1,
|
||||
cfg: float = 9.5,
|
||||
) -> bytes:
|
||||
async def sdprem(self, prompt: str, negative: str | bool = None, priority: str = None, steps: str = None,
|
||||
high_res_results: str = None, style: Style = Style.IMAGINE_V1, seed: str = None,
|
||||
ratio: Ratio = Ratio.RATIO_1X1, cfg: float = 9.5) -> bytes | None:
|
||||
"""Generates AI Art."""
|
||||
try:
|
||||
validated_cfg = validate_cfg(cfg)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while validating cfg: {e}")
|
||||
traceback.print_exc() # Print the full traceback for detailed debugging
|
||||
return None
|
||||
|
||||
try:
|
||||
async with self.session.post(
|
||||
url=f"{self.api}/sdprem",
|
||||
data={
|
||||
"model_version": self.version,
|
||||
"prompt": prompt + (style.value[3] or ""),
|
||||
"negative_prompt": negative or "",
|
||||
"style_id": style.value[0],
|
||||
"width": ratio.value[0],
|
||||
"height": ratio.value[1],
|
||||
"seed": seed or "",
|
||||
"steps": steps or "30",
|
||||
"cfg": validated_cfg,
|
||||
"priority": priority or "0",
|
||||
"high_res_results": high_res_results or "0",
|
||||
},
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
except Exception as e:
|
||||
print(f"An error occurred while making the request: {e}")
|
||||
return None
|
||||
for attempt in range(2):
|
||||
try:
|
||||
async with self.session.post(
|
||||
url=f"{self.api}/sdprem",
|
||||
data={
|
||||
"model_version": self.version,
|
||||
"prompt": prompt + (style.value[3] or ""),
|
||||
"negative_prompt": negative or "ugly, disfigured, low quality, blurry, nsfw",
|
||||
"style_id": style.value[0],
|
||||
"width": ratio.value[0],
|
||||
"height": ratio.value[1],
|
||||
"seed": seed or "",
|
||||
"steps": steps or "30",
|
||||
"cfg": validated_cfg,
|
||||
"priority": priority or "0",
|
||||
"high_res_results": high_res_results or "0"
|
||||
}
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
except Exception as e:
|
||||
print(f"An error occurred while making the request: {e}")
|
||||
traceback.print_exc() # Print the full traceback for detailed debugging
|
||||
if attempt == 0:
|
||||
await asyncio.sleep(0.4)
|
||||
print("Retrying....")
|
||||
else:
|
||||
return None
|
||||
|
||||
async def upscale(self, image: bytes) -> bytes:
|
||||
async def upscale(self, image: bytes) -> bytes | None:
|
||||
"""Upscales the image."""
|
||||
try:
|
||||
async with self.session.post(
|
||||
url=f"{self.api}/upscale",
|
||||
data={
|
||||
"model_version": self.version,
|
||||
"image": self.bytes_to_io(image, "test.png"),
|
||||
},
|
||||
url=f"{self.api}/upscale",
|
||||
data={
|
||||
"model_version": self.version,
|
||||
"image": self.bytes_to_io(image, "test.png")
|
||||
}
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
except Exception as e:
|
||||
print(f"An error occurred while making the request: {e}")
|
||||
return None
|
||||
|
||||
async def translate(self, prompt: str) -> str:
|
||||
"""Translates the prompt."""
|
||||
async with self.session.post(
|
||||
url=f"{self.api}/translate",
|
||||
data={"q": prompt, "source": detect(prompt), "target": "en"},
|
||||
) as resp:
|
||||
return (await resp.json())["translatedText"]
|
||||
|
||||
async def interrogator(self, image: bytes) -> str:
|
||||
"""Generates a prompt."""
|
||||
async with self.session.post(
|
||||
url=f"{self.api}/interrogator",
|
||||
data={
|
||||
"model_version": str(self.version),
|
||||
"image": self.bytes_to_io(image, "prompt_generator_temp.png"),
|
||||
},
|
||||
url=f"{self.api}/interrogator",
|
||||
data={
|
||||
"model_version": str(self.version),
|
||||
"image": self.bytes_to_io(image, "prompt_generator_temp.png")
|
||||
}
|
||||
) as resp:
|
||||
return await resp.text()
|
||||
|
||||
async def sdimg(
|
||||
self,
|
||||
image: bytes,
|
||||
prompt: str,
|
||||
negative: str = None,
|
||||
seed: str = None,
|
||||
cfg: float = 9.5,
|
||||
) -> bytes:
|
||||
async def sdimg(self, image: bytes, prompt: str, negative: str = None, seed: str = None, cfg: float = 9.5) -> bytes:
|
||||
"""Performs inpainting."""
|
||||
async with self.session.post(
|
||||
url=f"{self.api}/sdimg",
|
||||
data={
|
||||
"model_version": self.version,
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative or "",
|
||||
"seed": seed or "",
|
||||
"cfg": validate_cfg(cfg),
|
||||
"image": self.bytes_to_io(image, "image.png"),
|
||||
},
|
||||
url=f"{self.api}/sdimg",
|
||||
data={
|
||||
"model_version": self.version,
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative or "",
|
||||
"seed": seed or "",
|
||||
"cfg": validate_cfg(cfg),
|
||||
"image": self.bytes_to_io(image, "image.png")
|
||||
}
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
|
||||
async def controlnet(
|
||||
self,
|
||||
image: bytes,
|
||||
prompt: str,
|
||||
negative: str = None,
|
||||
cfg: float = 9.5,
|
||||
control: Control = Control.SCRIBBLE,
|
||||
style: Style = Style.IMAGINE_V1,
|
||||
seed: str = None,
|
||||
) -> bytes:
|
||||
"""Performs image remix."""
|
||||
async with self.session.post(
|
||||
url=f"{self.api}/controlnet",
|
||||
data={
|
||||
"model_version": self.version,
|
||||
"prompt": prompt + (style.value[3] or ""),
|
||||
"negative_prompt": negative or "",
|
||||
"strength": "0",
|
||||
"cfg": validate_cfg(cfg),
|
||||
"control": control.value,
|
||||
"style_id": str(style.value[0]),
|
||||
"seed": seed or "",
|
||||
"image": self.bytes_to_io(image, "image.png"),
|
||||
},
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
|
||||
|
||||
def run():
|
||||
if not USERNAME or not SERVER or not PASSWORD:
|
||||
print(
|
||||
@ -777,17 +722,12 @@ def run():
|
||||
else:
|
||||
prompt += arg + " "
|
||||
|
||||
async def generate_image(
|
||||
image_prompt, style_value, ratio_value, negative
|
||||
):
|
||||
async def generate_image(image_prompt, style_value, ratio_value, negative, upscale):
|
||||
if negative is None:
|
||||
negative = False
|
||||
imagine = AsyncImagine()
|
||||
filename = str(uuid.uuid4()) + ".png"
|
||||
try:
|
||||
style_enum = Style[style_value]
|
||||
ratio_enum = Ratio[ratio_value]
|
||||
except KeyError:
|
||||
style_enum = Style.IMAGINE_V3
|
||||
ratio_enum = Ratio.RATIO_1X1
|
||||
style_enum = Style[style_value]
|
||||
ratio_enum = Ratio[ratio_value]
|
||||
img_data = await imagine.sdprem(
|
||||
prompt=image_prompt,
|
||||
style=style_enum,
|
||||
@ -795,21 +735,23 @@ def run():
|
||||
priority="1",
|
||||
high_res_results="1",
|
||||
steps="70",
|
||||
negative=negative,
|
||||
negative=negative
|
||||
)
|
||||
|
||||
if upscale:
|
||||
img_data = await imagine.upscale(image=img_data)
|
||||
|
||||
try:
|
||||
with open(filename, mode="wb") as img_file:
|
||||
img_file.write(img_data)
|
||||
img_file = io.BytesIO(img_data)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while writing the image to file: {e}")
|
||||
print(
|
||||
f"An error occurred while creating the in-memory image file: {e}")
|
||||
return None
|
||||
|
||||
await imagine.close()
|
||||
return img_file
|
||||
|
||||
return filename
|
||||
|
||||
filename = await generate_image(prompt, style, ratio, negative)
|
||||
filename = await generate_image(prompt, style, ratio, negative, upscale=False)
|
||||
|
||||
await bot.api.send_image_message(
|
||||
room_id=room.room_id, image_filepath=filename
|
||||
|
Loading…
Reference in New Issue
Block a user