新增指定sampler, hires的功能
修复tags中方括号编码不正常的问题
This commit is contained in:
20
backend/__init__.py
Normal file
20
backend/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from ..config import config
|
||||
"""def AIDRAW():
|
||||
if config.novelai_mode=="novelai":
|
||||
from .novelai import AIDRAW
|
||||
elif config.novelai_mode=="naifu":
|
||||
from .naifu import AIDRAW
|
||||
elif config.novelai_mode=="sd":
|
||||
from .sd import AIDRAW
|
||||
else:
|
||||
raise RuntimeError(f"错误的mode设置,支持的字符串为'novelai','naifu','sd'")
|
||||
return AIDRAW()"""
|
||||
|
||||
if config.novelai_mode=="novelai":
|
||||
from .novelai import AIDRAW
|
||||
elif config.novelai_mode=="naifu":
|
||||
from .naifu import AIDRAW
|
||||
elif config.novelai_mode=="sd":
|
||||
from .sd import AIDRAW
|
||||
else:
|
||||
raise RuntimeError(f"错误的mode设置,支持的字符串为'novelai','naifu','sd'")
|
||||
266
backend/base.py
Normal file
266
backend/base.py
Normal file
@@ -0,0 +1,266 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import random
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
import aiohttp
|
||||
from nonebot import get_driver
|
||||
from nonebot.log import logger
|
||||
from PIL import Image
|
||||
|
||||
from ..config import config
|
||||
from ..utils import png2jpg
|
||||
from ..utils.data import shapemap
|
||||
|
||||
|
||||
class AIDRAW_BASE:
|
||||
max_resolution: int = 16
|
||||
sampler: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
group_id: str,
|
||||
tags: str = "",
|
||||
ntags: str = "",
|
||||
seed: int = None,
|
||||
scale: int = None,
|
||||
steps: int = None,
|
||||
batch: int = None,
|
||||
strength: float = None,
|
||||
noise: float = None,
|
||||
shape: str = "p",
|
||||
model: str = None,
|
||||
sampler: str = "",
|
||||
hires: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
AI绘画的核心部分,将与服务器通信的过程包装起来,并方便扩展服务器类型
|
||||
|
||||
:user_id: 用户id,必须
|
||||
:group_id: 群聊id,如果是私聊则应置为0,必须
|
||||
:tags: 图像的标签
|
||||
:ntags: 图像的反面标签
|
||||
:seed: 生成的种子,不指定的情况下随机生成
|
||||
:scale: 标签的参考度,值越高越贴近于标签,但可能产生过度锐化。范围为0-30,默认11
|
||||
:steps: 训练步数。范围为1-50,默认28.以图生图时强制50
|
||||
:batch: 同时生成数量
|
||||
:strength: 以图生图时使用,变化的强度。范围为0-1,默认0.7
|
||||
:noise: 以图生图时使用,变化的噪音,数值越大细节越多,但可能产生伪影,不建议超过strength。范围0-1,默认0.2
|
||||
:shape: 图像的形状,支持"p""s""l"三种,同时支持以"x"分割宽高的指定分辨率。
|
||||
该值会被设置限制,并不会严格遵守输入
|
||||
类初始化后,该参数会被拆分为:width:和:height:
|
||||
:model: 指定的模型,模型名称在配置文件中手动设置。不指定模型则按照负载均衡自动选择
|
||||
|
||||
AIDRAW还包含了以下几种内置的参数
|
||||
:status: 记录了AIDRAW的状态,默认为0等待中(处理中)
|
||||
非0的值为运行完成后的状态值,200和201为正常运行,其余值为产生错误
|
||||
:result: 当正常运行完成后,该参数为一个包含了生成图片bytes信息的数组
|
||||
:maxresolution: 一般不用管,用于限制不同服务器的最大分辨率
|
||||
如果你的SD经过了魔改支持更大分辨率可以修改该值并重新设置宽高
|
||||
:cost: 记录了本次生成需要花费多少点数,自动计算
|
||||
:signal: asyncio.Event类,可以作为信号使用。仅占位,需要自行实现相关方法
|
||||
"""
|
||||
self.status: int = 0
|
||||
self.result: list = []
|
||||
self.signal: asyncio.Event = None
|
||||
self.model = model
|
||||
self.time = time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.user_id: str = user_id
|
||||
self.tags: str = tags
|
||||
self.seed: list[int] = [seed or random.randint(0, 4294967295)]
|
||||
self.group_id: str = group_id
|
||||
self.scale: int = int(scale or 11)
|
||||
self.strength: float = strength or 0.7
|
||||
self.batch: int = batch or 1
|
||||
self.steps: int = steps or 28
|
||||
self.noise: float = noise or 0.2
|
||||
self.ntags: str = ntags
|
||||
self.img2img: bool = False
|
||||
self.image: str = None
|
||||
self.width, self.height = self.extract_shape(shape)
|
||||
self.sampler = sampler
|
||||
self.hires = hires
|
||||
# 数值合法检查
|
||||
if self.steps <= 0 or self.steps > (50 if not config.novelai_paid else 28):
|
||||
self.steps = 28
|
||||
if self.strength < 0 or self.strength > 1:
|
||||
self.strength = 0.7
|
||||
if self.noise < 0 or self.noise > 1:
|
||||
self.noise = 0.2
|
||||
if self.scale <= 0 or self.scale > 30:
|
||||
self.scale = 11
|
||||
# 多图时随机填充剩余seed
|
||||
for i in range(self.batch - 1):
|
||||
self.seed.append(random.randint(0, 4294967295))
|
||||
# 计算cost
|
||||
self.update_cost()
|
||||
|
||||
def extract_shape(self, shape: str):
|
||||
"""
|
||||
将shape拆分为width和height
|
||||
"""
|
||||
if shape:
|
||||
if "x" in shape:
|
||||
width, height, *_ = shape.split("x")
|
||||
if width.isdigit() and height.isdigit():
|
||||
return self.shape_set(int(width), int(height))
|
||||
else:
|
||||
return shapemap.get(shape)
|
||||
else:
|
||||
return shapemap.get(shape)
|
||||
else:
|
||||
return (512, 768)
|
||||
|
||||
def update_cost(self):
|
||||
"""
|
||||
更新cost
|
||||
"""
|
||||
if config.novelai_paid == 1:
|
||||
anlas = 0
|
||||
if (self.width * self.height > 409600) or self.image or self.batch > 1:
|
||||
anlas = round(
|
||||
self.width
|
||||
* self.height
|
||||
* self.strength
|
||||
* self.batch
|
||||
* self.steps
|
||||
/ 2293750
|
||||
)
|
||||
if anlas < 2:
|
||||
anlas = 2
|
||||
if self.user_id in get_driver().config.superusers:
|
||||
self.cost = 0
|
||||
else:
|
||||
self.cost = anlas
|
||||
elif config.novelai_paid == 2:
|
||||
anlas = round(
|
||||
self.width
|
||||
* self.height
|
||||
* self.strength
|
||||
* self.batch
|
||||
* self.steps
|
||||
/ 2293750
|
||||
)
|
||||
if anlas < 2:
|
||||
anlas = 2
|
||||
if self.user_id in get_driver().config.superusers:
|
||||
self.cost = 0
|
||||
else:
|
||||
self.cost = anlas
|
||||
else:
|
||||
self.cost = 0
|
||||
|
||||
def add_image(self, image: bytes):
|
||||
"""
|
||||
向类中添加图片,将其转化为以图生图模式
|
||||
也可用于修改类中已经存在的图片
|
||||
"""
|
||||
# 根据图片重写长宽
|
||||
tmpfile = BytesIO(image)
|
||||
image_ = Image.open(tmpfile)
|
||||
width, height = image_.size
|
||||
self.width, self.height = self.shape_set(width, height)
|
||||
self.image = str(base64.b64encode(image), "utf-8")
|
||||
self.steps = 50
|
||||
self.img2img = True
|
||||
self.update_cost()
|
||||
|
||||
def shape_set(self, width: int, height: int):
|
||||
"""
|
||||
设置宽高
|
||||
"""
|
||||
limit = 1024 if config.paid else 640
|
||||
if width * height > pow(min(config.novelai_size, limit), 2):
|
||||
if width <= height:
|
||||
ratio = height / width
|
||||
width: float = config.novelai_size / pow(ratio, 0.5)
|
||||
height: float = width * ratio
|
||||
else:
|
||||
ratio = width / height
|
||||
height: float = config.novelai_size / pow(ratio, 0.5)
|
||||
width: float = height * ratio
|
||||
base = round(max(width, height) / 64)
|
||||
if base > self.max_resolution:
|
||||
base = self.max_resolution
|
||||
if width <= height:
|
||||
return (round(width / height * base) * 64, 64 * base)
|
||||
else:
|
||||
return (64 * base, round(height / width * base) * 64)
|
||||
|
||||
async def post_(self, header: dict, post_api: str, json: dict):
|
||||
"""
|
||||
向服务器发送请求的核心函数,不要直接调用,请使用post函数
|
||||
:header: 请求头
|
||||
:post_api: 请求地址
|
||||
:json: 请求体
|
||||
"""
|
||||
# 请求交互
|
||||
async with aiohttp.ClientSession(headers=header) as session:
|
||||
# 向服务器发送请求
|
||||
async with session.post(post_api, json=json) as resp:
|
||||
if resp.status not in [200, 201]:
|
||||
logger.error(await resp.text())
|
||||
raise RuntimeError(f"与服务器沟通时发生{resp.status}错误")
|
||||
img = await self.fromresp(resp)
|
||||
logger.debug(f"获取到返回图片,正在处理")
|
||||
|
||||
# 将图片转化为jpg
|
||||
if config.novelai_save == 1:
|
||||
image_new = await png2jpg(img)
|
||||
else:
|
||||
image_new = base64.b64decode(img)
|
||||
self.result.append(image_new)
|
||||
return image_new
|
||||
|
||||
async def fromresp(self, resp):
|
||||
"""
|
||||
处理请求的返回内容,不要直接调用,请使用post函数
|
||||
"""
|
||||
img: str = await resp.text()
|
||||
return img.split("data:")[1]
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
运行核心函数,发送请求并处理
|
||||
"""
|
||||
pass
|
||||
|
||||
def keys(self):
|
||||
return (
|
||||
"seed",
|
||||
"scale",
|
||||
"strength",
|
||||
"noise",
|
||||
"sampler",
|
||||
"model",
|
||||
"steps",
|
||||
"width",
|
||||
"height",
|
||||
"img2img",
|
||||
)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return getattr(self, item)
|
||||
|
||||
def format(self):
|
||||
dict_self = dict(self)
|
||||
list = []
|
||||
str = ""
|
||||
for i, v in dict_self.items():
|
||||
str += f"{i}={v}\n"
|
||||
list.append(str)
|
||||
list.append(f"tags={self.tags}\n")
|
||||
list.append(f"ntags={self.ntags}")
|
||||
return list
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"time={self.time}\nuser_id={self.user_id}\ngroup_id={self.group_id}\ncost={self.cost}\nbatch={self.batch}\n"
|
||||
+ "".join(self.format())
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__().replace("\n", ";")
|
||||
34
backend/naifu.py
Normal file
34
backend/naifu.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from .base import AIDRAW_BASE
|
||||
from ..config import config
|
||||
class AIDRAW(AIDRAW_BASE):
|
||||
"""队列中的单个请求"""
|
||||
|
||||
async def post(self):
|
||||
header = {
|
||||
"content-type": "application/json",
|
||||
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
|
||||
}
|
||||
site=config.novelai_site or "127.0.0.1:6969"
|
||||
post_api="http://"+site + "/generate-stream"
|
||||
for i in range(self.batch):
|
||||
parameters = {
|
||||
"prompt":self.tags,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"qualityToggle": False,
|
||||
"scale": self.scale,
|
||||
"sampler": self.sampler,
|
||||
"steps": self.steps,
|
||||
"seed": self.seed[i],
|
||||
"n_samples": 1,
|
||||
"ucPreset": 0,
|
||||
"uc": self.ntags,
|
||||
}
|
||||
if self.img2img:
|
||||
parameters.update({
|
||||
"image": self.image,
|
||||
"strength": self.strength,
|
||||
"noise": self.noise
|
||||
})
|
||||
await self.post_(header, post_api,parameters)
|
||||
return self.result
|
||||
44
backend/novelai.py
Normal file
44
backend/novelai.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from ..config import config
|
||||
from .base import AIDRAW_BASE
|
||||
|
||||
class AIDRAW(AIDRAW_BASE):
|
||||
"""队列中的单个请求"""
|
||||
model: str = "nai-diffusion" if config.novelai_h else "safe-diffusion"
|
||||
|
||||
async def post(self):
|
||||
# 获取请求体
|
||||
header = {
|
||||
"authorization": "Bearer " + config.novelai_token,
|
||||
":authority": "https://api.novelai.net",
|
||||
":path": "/ai/generate-image",
|
||||
"content-type": "application/json",
|
||||
"referer": "https://novelai.net",
|
||||
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
|
||||
}
|
||||
post_api = "https://api.novelai.net/ai/generate-image"
|
||||
for i in range(self.batch):
|
||||
parameters = {
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"qualityToggle": False,
|
||||
"scale": self.scale,
|
||||
"sampler": self.sampler,
|
||||
"steps": self.steps,
|
||||
"seed": self.seed[i],
|
||||
"n_samples": 1,
|
||||
"ucPreset": 0,
|
||||
"uc": self.ntags,
|
||||
}
|
||||
if self.img2img:
|
||||
parameters.update({
|
||||
"image": self.image,
|
||||
"strength": self.strength,
|
||||
"noise": self.noise
|
||||
})
|
||||
json= {
|
||||
"input": self.tags,
|
||||
"model": self.model,
|
||||
"parameters": parameters
|
||||
}
|
||||
await self.post_(header, post_api,json)
|
||||
return self.result
|
||||
43
backend/sd.py
Normal file
43
backend/sd.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from .base import AIDRAW_BASE
|
||||
from ..config import config
|
||||
|
||||
|
||||
class AIDRAW(AIDRAW_BASE):
|
||||
"""队列中的单个请求"""
|
||||
max_resolution: int = 32
|
||||
|
||||
async def fromresp(self, resp):
|
||||
img: dict = await resp.json()
|
||||
return img["images"][0]
|
||||
|
||||
async def post(self):
|
||||
site=config.novelai_site or "127.0.0.1:7860"
|
||||
header = {
|
||||
"content-type": "application/json",
|
||||
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
|
||||
}
|
||||
post_api = f"http://{site}/sdapi/v1/img2img" if self.img2img else f"http://{site}/sdapi/v1/txt2img"
|
||||
for i in range(self.batch):
|
||||
parameters = {
|
||||
"prompt": self.tags,
|
||||
"seed": self.seed[i],
|
||||
"steps": self.steps,
|
||||
"cfg_scale": self.scale,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"sampler_name": self.sampler,
|
||||
"negative_prompt": self.ntags,
|
||||
"enable_hr": self.hires,
|
||||
"denoising_strength": 0.7,
|
||||
"hr_scale": 2,
|
||||
"hr_upscaler": "Latent"
|
||||
}
|
||||
print("向API发送以下参数:")
|
||||
print(parameters)
|
||||
if self.img2img:
|
||||
parameters.update({
|
||||
"init_images": ["data:image/jpeg;base64,"+self.image],
|
||||
"denoising_strength": self.strength,
|
||||
})
|
||||
await self.post_(header, post_api, parameters)
|
||||
return self.result
|
||||
Reference in New Issue
Block a user