新增指定sampler, hires的功能

修复tags中方括号编码不正常的问题
This commit is contained in:
2023-03-01 21:14:18 +08:00
commit deae8e4888
29 changed files with 1503 additions and 0 deletions

20
backend/__init__.py Normal file
View File

@@ -0,0 +1,20 @@
from ..config import config
"""def AIDRAW():
if config.novelai_mode=="novelai":
from .novelai import AIDRAW
elif config.novelai_mode=="naifu":
from .naifu import AIDRAW
elif config.novelai_mode=="sd":
from .sd import AIDRAW
else:
raise RuntimeError(f"错误的mode设置支持的字符串为'novelai','naifu','sd'")
return AIDRAW()"""
if config.novelai_mode=="novelai":
from .novelai import AIDRAW
elif config.novelai_mode=="naifu":
from .naifu import AIDRAW
elif config.novelai_mode=="sd":
from .sd import AIDRAW
else:
raise RuntimeError(f"错误的mode设置支持的字符串为'novelai','naifu','sd'")

266
backend/base.py Normal file
View File

@@ -0,0 +1,266 @@
import asyncio
import base64
import random
import time
from io import BytesIO
import aiohttp
from nonebot import get_driver
from nonebot.log import logger
from PIL import Image
from ..config import config
from ..utils import png2jpg
from ..utils.data import shapemap
class AIDRAW_BASE:
max_resolution: int = 16
sampler: str
def __init__(
self,
user_id: str,
group_id: str,
tags: str = "",
ntags: str = "",
seed: int = None,
scale: int = None,
steps: int = None,
batch: int = None,
strength: float = None,
noise: float = None,
shape: str = "p",
model: str = None,
sampler: str = "",
hires: bool = False,
**kwargs,
):
"""
AI绘画的核心部分,将与服务器通信的过程包装起来,并方便扩展服务器类型
:user_id: 用户id,必须
:group_id: 群聊id,如果是私聊则应置为0,必须
:tags: 图像的标签
:ntags: 图像的反面标签
:seed: 生成的种子,不指定的情况下随机生成
:scale: 标签的参考度,值越高越贴近于标签,但可能产生过度锐化。范围为0-30,默认11
:steps: 训练步数。范围为1-50,默认28.以图生图时强制50
:batch: 同时生成数量
:strength: 以图生图时使用,变化的强度。范围为0-1,默认0.7
:noise: 以图生图时使用,变化的噪音,数值越大细节越多,但可能产生伪影,不建议超过strength。范围0-1,默认0.2
:shape: 图像的形状,支持"p""s""l"三种,同时支持以"x"分割宽高的指定分辨率。
该值会被设置限制,并不会严格遵守输入
类初始化后,该参数会被拆分为:width:和:height:
:model: 指定的模型,模型名称在配置文件中手动设置。不指定模型则按照负载均衡自动选择
AIDRAW还包含了以下几种内置的参数
:status: 记录了AIDRAW的状态,默认为0等待中(处理中)
非0的值为运行完成后的状态值,200和201为正常运行,其余值为产生错误
:result: 当正常运行完成后,该参数为一个包含了生成图片bytes信息的数组
:maxresolution: 一般不用管,用于限制不同服务器的最大分辨率
如果你的SD经过了魔改支持更大分辨率可以修改该值并重新设置宽高
:cost: 记录了本次生成需要花费多少点数,自动计算
:signal: asyncio.Event类,可以作为信号使用。仅占位,需要自行实现相关方法
"""
self.status: int = 0
self.result: list = []
self.signal: asyncio.Event = None
self.model = model
self.time = time.strftime("%Y-%m-%d %H:%M:%S")
self.user_id: str = user_id
self.tags: str = tags
self.seed: list[int] = [seed or random.randint(0, 4294967295)]
self.group_id: str = group_id
self.scale: int = int(scale or 11)
self.strength: float = strength or 0.7
self.batch: int = batch or 1
self.steps: int = steps or 28
self.noise: float = noise or 0.2
self.ntags: str = ntags
self.img2img: bool = False
self.image: str = None
self.width, self.height = self.extract_shape(shape)
self.sampler = sampler
self.hires = hires
# 数值合法检查
if self.steps <= 0 or self.steps > (50 if not config.novelai_paid else 28):
self.steps = 28
if self.strength < 0 or self.strength > 1:
self.strength = 0.7
if self.noise < 0 or self.noise > 1:
self.noise = 0.2
if self.scale <= 0 or self.scale > 30:
self.scale = 11
# 多图时随机填充剩余seed
for i in range(self.batch - 1):
self.seed.append(random.randint(0, 4294967295))
# 计算cost
self.update_cost()
def extract_shape(self, shape: str):
"""
将shape拆分为width和height
"""
if shape:
if "x" in shape:
width, height, *_ = shape.split("x")
if width.isdigit() and height.isdigit():
return self.shape_set(int(width), int(height))
else:
return shapemap.get(shape)
else:
return shapemap.get(shape)
else:
return (512, 768)
def update_cost(self):
"""
更新cost
"""
if config.novelai_paid == 1:
anlas = 0
if (self.width * self.height > 409600) or self.image or self.batch > 1:
anlas = round(
self.width
* self.height
* self.strength
* self.batch
* self.steps
/ 2293750
)
if anlas < 2:
anlas = 2
if self.user_id in get_driver().config.superusers:
self.cost = 0
else:
self.cost = anlas
elif config.novelai_paid == 2:
anlas = round(
self.width
* self.height
* self.strength
* self.batch
* self.steps
/ 2293750
)
if anlas < 2:
anlas = 2
if self.user_id in get_driver().config.superusers:
self.cost = 0
else:
self.cost = anlas
else:
self.cost = 0
def add_image(self, image: bytes):
"""
向类中添加图片,将其转化为以图生图模式
也可用于修改类中已经存在的图片
"""
# 根据图片重写长宽
tmpfile = BytesIO(image)
image_ = Image.open(tmpfile)
width, height = image_.size
self.width, self.height = self.shape_set(width, height)
self.image = str(base64.b64encode(image), "utf-8")
self.steps = 50
self.img2img = True
self.update_cost()
def shape_set(self, width: int, height: int):
"""
设置宽高
"""
limit = 1024 if config.paid else 640
if width * height > pow(min(config.novelai_size, limit), 2):
if width <= height:
ratio = height / width
width: float = config.novelai_size / pow(ratio, 0.5)
height: float = width * ratio
else:
ratio = width / height
height: float = config.novelai_size / pow(ratio, 0.5)
width: float = height * ratio
base = round(max(width, height) / 64)
if base > self.max_resolution:
base = self.max_resolution
if width <= height:
return (round(width / height * base) * 64, 64 * base)
else:
return (64 * base, round(height / width * base) * 64)
async def post_(self, header: dict, post_api: str, json: dict):
"""
向服务器发送请求的核心函数不要直接调用请使用post函数
:header: 请求头
:post_api: 请求地址
:json: 请求体
"""
# 请求交互
async with aiohttp.ClientSession(headers=header) as session:
# 向服务器发送请求
async with session.post(post_api, json=json) as resp:
if resp.status not in [200, 201]:
logger.error(await resp.text())
raise RuntimeError(f"与服务器沟通时发生{resp.status}错误")
img = await self.fromresp(resp)
logger.debug(f"获取到返回图片,正在处理")
# 将图片转化为jpg
if config.novelai_save == 1:
image_new = await png2jpg(img)
else:
image_new = base64.b64decode(img)
self.result.append(image_new)
return image_new
async def fromresp(self, resp):
"""
处理请求的返回内容不要直接调用请使用post函数
"""
img: str = await resp.text()
return img.split("data:")[1]
def run(self):
"""
运行核心函数,发送请求并处理
"""
pass
def keys(self):
return (
"seed",
"scale",
"strength",
"noise",
"sampler",
"model",
"steps",
"width",
"height",
"img2img",
)
def __getitem__(self, item):
return getattr(self, item)
def format(self):
dict_self = dict(self)
list = []
str = ""
for i, v in dict_self.items():
str += f"{i}={v}\n"
list.append(str)
list.append(f"tags={self.tags}\n")
list.append(f"ntags={self.ntags}")
return list
def __repr__(self):
return (
f"time={self.time}\nuser_id={self.user_id}\ngroup_id={self.group_id}\ncost={self.cost}\nbatch={self.batch}\n"
+ "".join(self.format())
)
def __str__(self):
return self.__repr__().replace("\n", ";")

34
backend/naifu.py Normal file
View File

@@ -0,0 +1,34 @@
from .base import AIDRAW_BASE
from ..config import config
class AIDRAW(AIDRAW_BASE):
"""队列中的单个请求"""
async def post(self):
header = {
"content-type": "application/json",
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
}
site=config.novelai_site or "127.0.0.1:6969"
post_api="http://"+site + "/generate-stream"
for i in range(self.batch):
parameters = {
"prompt":self.tags,
"width": self.width,
"height": self.height,
"qualityToggle": False,
"scale": self.scale,
"sampler": self.sampler,
"steps": self.steps,
"seed": self.seed[i],
"n_samples": 1,
"ucPreset": 0,
"uc": self.ntags,
}
if self.img2img:
parameters.update({
"image": self.image,
"strength": self.strength,
"noise": self.noise
})
await self.post_(header, post_api,parameters)
return self.result

44
backend/novelai.py Normal file
View File

@@ -0,0 +1,44 @@
from ..config import config
from .base import AIDRAW_BASE
class AIDRAW(AIDRAW_BASE):
"""队列中的单个请求"""
model: str = "nai-diffusion" if config.novelai_h else "safe-diffusion"
async def post(self):
# 获取请求体
header = {
"authorization": "Bearer " + config.novelai_token,
":authority": "https://api.novelai.net",
":path": "/ai/generate-image",
"content-type": "application/json",
"referer": "https://novelai.net",
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
}
post_api = "https://api.novelai.net/ai/generate-image"
for i in range(self.batch):
parameters = {
"width": self.width,
"height": self.height,
"qualityToggle": False,
"scale": self.scale,
"sampler": self.sampler,
"steps": self.steps,
"seed": self.seed[i],
"n_samples": 1,
"ucPreset": 0,
"uc": self.ntags,
}
if self.img2img:
parameters.update({
"image": self.image,
"strength": self.strength,
"noise": self.noise
})
json= {
"input": self.tags,
"model": self.model,
"parameters": parameters
}
await self.post_(header, post_api,json)
return self.result

43
backend/sd.py Normal file
View File

@@ -0,0 +1,43 @@
from .base import AIDRAW_BASE
from ..config import config
class AIDRAW(AIDRAW_BASE):
"""队列中的单个请求"""
max_resolution: int = 32
async def fromresp(self, resp):
img: dict = await resp.json()
return img["images"][0]
async def post(self):
site=config.novelai_site or "127.0.0.1:7860"
header = {
"content-type": "application/json",
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
}
post_api = f"http://{site}/sdapi/v1/img2img" if self.img2img else f"http://{site}/sdapi/v1/txt2img"
for i in range(self.batch):
parameters = {
"prompt": self.tags,
"seed": self.seed[i],
"steps": self.steps,
"cfg_scale": self.scale,
"width": self.width,
"height": self.height,
"sampler_name": self.sampler,
"negative_prompt": self.ntags,
"enable_hr": self.hires,
"denoising_strength": 0.7,
"hr_scale": 2,
"hr_upscaler": "Latent"
}
print("向API发送以下参数")
print(parameters)
if self.img2img:
parameters.update({
"init_images": ["data:image/jpeg;base64,"+self.image],
"denoising_strength": self.strength,
})
await self.post_(header, post_api, parameters)
return self.result