267 lines
9.1 KiB
Python
267 lines
9.1 KiB
Python
|
import asyncio
|
|||
|
import base64
|
|||
|
import random
|
|||
|
import time
|
|||
|
from io import BytesIO
|
|||
|
|
|||
|
import aiohttp
|
|||
|
from nonebot import get_driver
|
|||
|
from nonebot.log import logger
|
|||
|
from PIL import Image
|
|||
|
|
|||
|
from ..config import config
|
|||
|
from ..utils import png2jpg
|
|||
|
from ..utils.data import shapemap
|
|||
|
|
|||
|
|
|||
|
class AIDRAW_BASE:
|
|||
|
max_resolution: int = 16
|
|||
|
sampler: str
|
|||
|
|
|||
|
def __init__(
|
|||
|
self,
|
|||
|
user_id: str,
|
|||
|
group_id: str,
|
|||
|
tags: str = "",
|
|||
|
ntags: str = "",
|
|||
|
seed: int = None,
|
|||
|
scale: int = None,
|
|||
|
steps: int = None,
|
|||
|
batch: int = None,
|
|||
|
strength: float = None,
|
|||
|
noise: float = None,
|
|||
|
shape: str = "p",
|
|||
|
model: str = None,
|
|||
|
sampler: str = "",
|
|||
|
hires: bool = False,
|
|||
|
**kwargs,
|
|||
|
):
|
|||
|
"""
|
|||
|
AI绘画的核心部分,将与服务器通信的过程包装起来,并方便扩展服务器类型
|
|||
|
|
|||
|
:user_id: 用户id,必须
|
|||
|
:group_id: 群聊id,如果是私聊则应置为0,必须
|
|||
|
:tags: 图像的标签
|
|||
|
:ntags: 图像的反面标签
|
|||
|
:seed: 生成的种子,不指定的情况下随机生成
|
|||
|
:scale: 标签的参考度,值越高越贴近于标签,但可能产生过度锐化。范围为0-30,默认11
|
|||
|
:steps: 训练步数。范围为1-50,默认28.以图生图时强制50
|
|||
|
:batch: 同时生成数量
|
|||
|
:strength: 以图生图时使用,变化的强度。范围为0-1,默认0.7
|
|||
|
:noise: 以图生图时使用,变化的噪音,数值越大细节越多,但可能产生伪影,不建议超过strength。范围0-1,默认0.2
|
|||
|
:shape: 图像的形状,支持"p""s""l"三种,同时支持以"x"分割宽高的指定分辨率。
|
|||
|
该值会被设置限制,并不会严格遵守输入
|
|||
|
类初始化后,该参数会被拆分为:width:和:height:
|
|||
|
:model: 指定的模型,模型名称在配置文件中手动设置。不指定模型则按照负载均衡自动选择
|
|||
|
|
|||
|
AIDRAW还包含了以下几种内置的参数
|
|||
|
:status: 记录了AIDRAW的状态,默认为0等待中(处理中)
|
|||
|
非0的值为运行完成后的状态值,200和201为正常运行,其余值为产生错误
|
|||
|
:result: 当正常运行完成后,该参数为一个包含了生成图片bytes信息的数组
|
|||
|
:maxresolution: 一般不用管,用于限制不同服务器的最大分辨率
|
|||
|
如果你的SD经过了魔改支持更大分辨率可以修改该值并重新设置宽高
|
|||
|
:cost: 记录了本次生成需要花费多少点数,自动计算
|
|||
|
:signal: asyncio.Event类,可以作为信号使用。仅占位,需要自行实现相关方法
|
|||
|
"""
|
|||
|
self.status: int = 0
|
|||
|
self.result: list = []
|
|||
|
self.signal: asyncio.Event = None
|
|||
|
self.model = model
|
|||
|
self.time = time.strftime("%Y-%m-%d %H:%M:%S")
|
|||
|
self.user_id: str = user_id
|
|||
|
self.tags: str = tags
|
|||
|
self.seed: list[int] = [seed or random.randint(0, 4294967295)]
|
|||
|
self.group_id: str = group_id
|
|||
|
self.scale: int = int(scale or 11)
|
|||
|
self.strength: float = strength or 0.7
|
|||
|
self.batch: int = batch or 1
|
|||
|
self.steps: int = steps or 28
|
|||
|
self.noise: float = noise or 0.2
|
|||
|
self.ntags: str = ntags
|
|||
|
self.img2img: bool = False
|
|||
|
self.image: str = None
|
|||
|
self.width, self.height = self.extract_shape(shape)
|
|||
|
self.sampler = sampler
|
|||
|
self.hires = hires
|
|||
|
# 数值合法检查
|
|||
|
if self.steps <= 0 or self.steps > (50 if not config.novelai_paid else 28):
|
|||
|
self.steps = 28
|
|||
|
if self.strength < 0 or self.strength > 1:
|
|||
|
self.strength = 0.7
|
|||
|
if self.noise < 0 or self.noise > 1:
|
|||
|
self.noise = 0.2
|
|||
|
if self.scale <= 0 or self.scale > 30:
|
|||
|
self.scale = 11
|
|||
|
# 多图时随机填充剩余seed
|
|||
|
for i in range(self.batch - 1):
|
|||
|
self.seed.append(random.randint(0, 4294967295))
|
|||
|
# 计算cost
|
|||
|
self.update_cost()
|
|||
|
|
|||
|
def extract_shape(self, shape: str):
|
|||
|
"""
|
|||
|
将shape拆分为width和height
|
|||
|
"""
|
|||
|
if shape:
|
|||
|
if "x" in shape:
|
|||
|
width, height, *_ = shape.split("x")
|
|||
|
if width.isdigit() and height.isdigit():
|
|||
|
return self.shape_set(int(width), int(height))
|
|||
|
else:
|
|||
|
return shapemap.get(shape)
|
|||
|
else:
|
|||
|
return shapemap.get(shape)
|
|||
|
else:
|
|||
|
return (512, 768)
|
|||
|
|
|||
|
def update_cost(self):
|
|||
|
"""
|
|||
|
更新cost
|
|||
|
"""
|
|||
|
if config.novelai_paid == 1:
|
|||
|
anlas = 0
|
|||
|
if (self.width * self.height > 409600) or self.image or self.batch > 1:
|
|||
|
anlas = round(
|
|||
|
self.width
|
|||
|
* self.height
|
|||
|
* self.strength
|
|||
|
* self.batch
|
|||
|
* self.steps
|
|||
|
/ 2293750
|
|||
|
)
|
|||
|
if anlas < 2:
|
|||
|
anlas = 2
|
|||
|
if self.user_id in get_driver().config.superusers:
|
|||
|
self.cost = 0
|
|||
|
else:
|
|||
|
self.cost = anlas
|
|||
|
elif config.novelai_paid == 2:
|
|||
|
anlas = round(
|
|||
|
self.width
|
|||
|
* self.height
|
|||
|
* self.strength
|
|||
|
* self.batch
|
|||
|
* self.steps
|
|||
|
/ 2293750
|
|||
|
)
|
|||
|
if anlas < 2:
|
|||
|
anlas = 2
|
|||
|
if self.user_id in get_driver().config.superusers:
|
|||
|
self.cost = 0
|
|||
|
else:
|
|||
|
self.cost = anlas
|
|||
|
else:
|
|||
|
self.cost = 0
|
|||
|
|
|||
|
def add_image(self, image: bytes):
|
|||
|
"""
|
|||
|
向类中添加图片,将其转化为以图生图模式
|
|||
|
也可用于修改类中已经存在的图片
|
|||
|
"""
|
|||
|
# 根据图片重写长宽
|
|||
|
tmpfile = BytesIO(image)
|
|||
|
image_ = Image.open(tmpfile)
|
|||
|
width, height = image_.size
|
|||
|
self.width, self.height = self.shape_set(width, height)
|
|||
|
self.image = str(base64.b64encode(image), "utf-8")
|
|||
|
self.steps = 50
|
|||
|
self.img2img = True
|
|||
|
self.update_cost()
|
|||
|
|
|||
|
def shape_set(self, width: int, height: int):
|
|||
|
"""
|
|||
|
设置宽高
|
|||
|
"""
|
|||
|
limit = 1024 if config.paid else 640
|
|||
|
if width * height > pow(min(config.novelai_size, limit), 2):
|
|||
|
if width <= height:
|
|||
|
ratio = height / width
|
|||
|
width: float = config.novelai_size / pow(ratio, 0.5)
|
|||
|
height: float = width * ratio
|
|||
|
else:
|
|||
|
ratio = width / height
|
|||
|
height: float = config.novelai_size / pow(ratio, 0.5)
|
|||
|
width: float = height * ratio
|
|||
|
base = round(max(width, height) / 64)
|
|||
|
if base > self.max_resolution:
|
|||
|
base = self.max_resolution
|
|||
|
if width <= height:
|
|||
|
return (round(width / height * base) * 64, 64 * base)
|
|||
|
else:
|
|||
|
return (64 * base, round(height / width * base) * 64)
|
|||
|
|
|||
|
async def post_(self, header: dict, post_api: str, json: dict):
|
|||
|
"""
|
|||
|
向服务器发送请求的核心函数,不要直接调用,请使用post函数
|
|||
|
:header: 请求头
|
|||
|
:post_api: 请求地址
|
|||
|
:json: 请求体
|
|||
|
"""
|
|||
|
# 请求交互
|
|||
|
async with aiohttp.ClientSession(headers=header) as session:
|
|||
|
# 向服务器发送请求
|
|||
|
async with session.post(post_api, json=json) as resp:
|
|||
|
if resp.status not in [200, 201]:
|
|||
|
logger.error(await resp.text())
|
|||
|
raise RuntimeError(f"与服务器沟通时发生{resp.status}错误")
|
|||
|
img = await self.fromresp(resp)
|
|||
|
logger.debug(f"获取到返回图片,正在处理")
|
|||
|
|
|||
|
# 将图片转化为jpg
|
|||
|
if config.novelai_save == 1:
|
|||
|
image_new = await png2jpg(img)
|
|||
|
else:
|
|||
|
image_new = base64.b64decode(img)
|
|||
|
self.result.append(image_new)
|
|||
|
return image_new
|
|||
|
|
|||
|
async def fromresp(self, resp):
|
|||
|
"""
|
|||
|
处理请求的返回内容,不要直接调用,请使用post函数
|
|||
|
"""
|
|||
|
img: str = await resp.text()
|
|||
|
return img.split("data:")[1]
|
|||
|
|
|||
|
def run(self):
|
|||
|
"""
|
|||
|
运行核心函数,发送请求并处理
|
|||
|
"""
|
|||
|
pass
|
|||
|
|
|||
|
def keys(self):
|
|||
|
return (
|
|||
|
"seed",
|
|||
|
"scale",
|
|||
|
"strength",
|
|||
|
"noise",
|
|||
|
"sampler",
|
|||
|
"model",
|
|||
|
"steps",
|
|||
|
"width",
|
|||
|
"height",
|
|||
|
"img2img",
|
|||
|
)
|
|||
|
|
|||
|
def __getitem__(self, item):
|
|||
|
return getattr(self, item)
|
|||
|
|
|||
|
def format(self):
|
|||
|
dict_self = dict(self)
|
|||
|
list = []
|
|||
|
str = ""
|
|||
|
for i, v in dict_self.items():
|
|||
|
str += f"{i}={v}\n"
|
|||
|
list.append(str)
|
|||
|
list.append(f"tags={self.tags}\n")
|
|||
|
list.append(f"ntags={self.ntags}")
|
|||
|
return list
|
|||
|
|
|||
|
def __repr__(self):
|
|||
|
return (
|
|||
|
f"time={self.time}\nuser_id={self.user_id}\ngroup_id={self.group_id}\ncost={self.cost}\nbatch={self.batch}\n"
|
|||
|
+ "".join(self.format())
|
|||
|
)
|
|||
|
|
|||
|
def __str__(self):
|
|||
|
return self.__repr__().replace("\n", ";")
|