新增指定sampler, hires的功能
修复tags中方括号编码不正常的问题
This commit is contained in:
		
							
								
								
									
										20
									
								
								backend/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								backend/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
			
		||||
from ..config import config
 | 
			
		||||
"""def AIDRAW():
 | 
			
		||||
    if config.novelai_mode=="novelai":
 | 
			
		||||
        from .novelai import AIDRAW
 | 
			
		||||
    elif config.novelai_mode=="naifu":
 | 
			
		||||
        from .naifu import AIDRAW
 | 
			
		||||
    elif config.novelai_mode=="sd":
 | 
			
		||||
        from .sd import AIDRAW
 | 
			
		||||
    else:
 | 
			
		||||
        raise RuntimeError(f"错误的mode设置,支持的字符串为'novelai','naifu','sd'")
 | 
			
		||||
    return AIDRAW()"""
 | 
			
		||||
 | 
			
		||||
if config.novelai_mode=="novelai":
 | 
			
		||||
    from .novelai import AIDRAW
 | 
			
		||||
elif config.novelai_mode=="naifu":
 | 
			
		||||
    from .naifu import AIDRAW
 | 
			
		||||
elif config.novelai_mode=="sd":
 | 
			
		||||
    from .sd import AIDRAW
 | 
			
		||||
else:
 | 
			
		||||
    raise RuntimeError(f"错误的mode设置,支持的字符串为'novelai','naifu','sd'")
 | 
			
		||||
							
								
								
									
										266
									
								
								backend/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										266
									
								
								backend/base.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,266 @@
 | 
			
		||||
import asyncio
 | 
			
		||||
import base64
 | 
			
		||||
import random
 | 
			
		||||
import time
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
 | 
			
		||||
import aiohttp
 | 
			
		||||
from nonebot import get_driver
 | 
			
		||||
from nonebot.log import logger
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
from ..config import config
 | 
			
		||||
from ..utils import png2jpg
 | 
			
		||||
from ..utils.data import shapemap
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AIDRAW_BASE:
 | 
			
		||||
    max_resolution: int = 16
 | 
			
		||||
    sampler: str
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        user_id: str,
 | 
			
		||||
        group_id: str,
 | 
			
		||||
        tags: str = "",
 | 
			
		||||
        ntags: str = "",
 | 
			
		||||
        seed: int = None,
 | 
			
		||||
        scale: int = None,
 | 
			
		||||
        steps: int = None,
 | 
			
		||||
        batch: int = None,
 | 
			
		||||
        strength: float = None,
 | 
			
		||||
        noise: float = None,
 | 
			
		||||
        shape: str = "p",
 | 
			
		||||
        model: str = None,
 | 
			
		||||
        sampler: str = "",
 | 
			
		||||
        hires: bool = False,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        AI绘画的核心部分,将与服务器通信的过程包装起来,并方便扩展服务器类型
 | 
			
		||||
 | 
			
		||||
        :user_id: 用户id,必须
 | 
			
		||||
        :group_id: 群聊id,如果是私聊则应置为0,必须
 | 
			
		||||
        :tags: 图像的标签
 | 
			
		||||
        :ntags: 图像的反面标签
 | 
			
		||||
        :seed: 生成的种子,不指定的情况下随机生成
 | 
			
		||||
        :scale: 标签的参考度,值越高越贴近于标签,但可能产生过度锐化。范围为0-30,默认11
 | 
			
		||||
        :steps: 训练步数。范围为1-50,默认28.以图生图时强制50
 | 
			
		||||
        :batch: 同时生成数量
 | 
			
		||||
        :strength: 以图生图时使用,变化的强度。范围为0-1,默认0.7
 | 
			
		||||
        :noise: 以图生图时使用,变化的噪音,数值越大细节越多,但可能产生伪影,不建议超过strength。范围0-1,默认0.2
 | 
			
		||||
        :shape: 图像的形状,支持"p""s""l"三种,同时支持以"x"分割宽高的指定分辨率。
 | 
			
		||||
                该值会被设置限制,并不会严格遵守输入
 | 
			
		||||
                类初始化后,该参数会被拆分为:width:和:height:
 | 
			
		||||
        :model: 指定的模型,模型名称在配置文件中手动设置。不指定模型则按照负载均衡自动选择
 | 
			
		||||
 | 
			
		||||
        AIDRAW还包含了以下几种内置的参数
 | 
			
		||||
        :status: 记录了AIDRAW的状态,默认为0等待中(处理中)
 | 
			
		||||
                非0的值为运行完成后的状态值,200和201为正常运行,其余值为产生错误
 | 
			
		||||
        :result: 当正常运行完成后,该参数为一个包含了生成图片bytes信息的数组
 | 
			
		||||
        :maxresolution: 一般不用管,用于限制不同服务器的最大分辨率
 | 
			
		||||
                如果你的SD经过了魔改支持更大分辨率可以修改该值并重新设置宽高
 | 
			
		||||
        :cost: 记录了本次生成需要花费多少点数,自动计算
 | 
			
		||||
        :signal: asyncio.Event类,可以作为信号使用。仅占位,需要自行实现相关方法
 | 
			
		||||
        """
 | 
			
		||||
        self.status: int = 0
 | 
			
		||||
        self.result: list = []
 | 
			
		||||
        self.signal: asyncio.Event = None
 | 
			
		||||
        self.model = model
 | 
			
		||||
        self.time = time.strftime("%Y-%m-%d %H:%M:%S")
 | 
			
		||||
        self.user_id: str = user_id
 | 
			
		||||
        self.tags: str = tags
 | 
			
		||||
        self.seed: list[int] = [seed or random.randint(0, 4294967295)]
 | 
			
		||||
        self.group_id: str = group_id
 | 
			
		||||
        self.scale: int = int(scale or 11)
 | 
			
		||||
        self.strength: float = strength or 0.7
 | 
			
		||||
        self.batch: int = batch or 1
 | 
			
		||||
        self.steps: int = steps or 28
 | 
			
		||||
        self.noise: float = noise or 0.2
 | 
			
		||||
        self.ntags: str = ntags
 | 
			
		||||
        self.img2img: bool = False
 | 
			
		||||
        self.image: str = None
 | 
			
		||||
        self.width, self.height = self.extract_shape(shape)
 | 
			
		||||
        self.sampler = sampler
 | 
			
		||||
        self.hires = hires
 | 
			
		||||
        # 数值合法检查
 | 
			
		||||
        if self.steps <= 0 or self.steps > (50 if not config.novelai_paid else 28):
 | 
			
		||||
            self.steps = 28
 | 
			
		||||
        if self.strength < 0 or self.strength > 1:
 | 
			
		||||
            self.strength = 0.7
 | 
			
		||||
        if self.noise < 0 or self.noise > 1:
 | 
			
		||||
            self.noise = 0.2
 | 
			
		||||
        if self.scale <= 0 or self.scale > 30:
 | 
			
		||||
            self.scale = 11
 | 
			
		||||
        # 多图时随机填充剩余seed
 | 
			
		||||
        for i in range(self.batch - 1):
 | 
			
		||||
            self.seed.append(random.randint(0, 4294967295))
 | 
			
		||||
        # 计算cost
 | 
			
		||||
        self.update_cost()
 | 
			
		||||
 | 
			
		||||
    def extract_shape(self, shape: str):
 | 
			
		||||
        """
 | 
			
		||||
        将shape拆分为width和height
 | 
			
		||||
        """
 | 
			
		||||
        if shape:
 | 
			
		||||
            if "x" in shape:
 | 
			
		||||
                width, height, *_ = shape.split("x")
 | 
			
		||||
                if width.isdigit() and height.isdigit():
 | 
			
		||||
                    return self.shape_set(int(width), int(height))
 | 
			
		||||
                else:
 | 
			
		||||
                    return shapemap.get(shape)
 | 
			
		||||
            else:
 | 
			
		||||
                return shapemap.get(shape)
 | 
			
		||||
        else:
 | 
			
		||||
            return (512, 768)
 | 
			
		||||
 | 
			
		||||
    def update_cost(self):
 | 
			
		||||
        """
 | 
			
		||||
        更新cost
 | 
			
		||||
        """
 | 
			
		||||
        if config.novelai_paid == 1:
 | 
			
		||||
            anlas = 0
 | 
			
		||||
            if (self.width * self.height > 409600) or self.image or self.batch > 1:
 | 
			
		||||
                anlas = round(
 | 
			
		||||
                    self.width
 | 
			
		||||
                    * self.height
 | 
			
		||||
                    * self.strength
 | 
			
		||||
                    * self.batch
 | 
			
		||||
                    * self.steps
 | 
			
		||||
                    / 2293750
 | 
			
		||||
                )
 | 
			
		||||
                if anlas < 2:
 | 
			
		||||
                    anlas = 2
 | 
			
		||||
            if self.user_id in get_driver().config.superusers:
 | 
			
		||||
                self.cost = 0
 | 
			
		||||
            else:
 | 
			
		||||
                self.cost = anlas
 | 
			
		||||
        elif config.novelai_paid == 2:
 | 
			
		||||
            anlas = round(
 | 
			
		||||
                self.width
 | 
			
		||||
                * self.height
 | 
			
		||||
                * self.strength
 | 
			
		||||
                * self.batch
 | 
			
		||||
                * self.steps
 | 
			
		||||
                / 2293750
 | 
			
		||||
            )
 | 
			
		||||
            if anlas < 2:
 | 
			
		||||
                anlas = 2
 | 
			
		||||
            if self.user_id in get_driver().config.superusers:
 | 
			
		||||
                self.cost = 0
 | 
			
		||||
            else:
 | 
			
		||||
                self.cost = anlas
 | 
			
		||||
        else:
 | 
			
		||||
            self.cost = 0
 | 
			
		||||
 | 
			
		||||
    def add_image(self, image: bytes):
 | 
			
		||||
        """
 | 
			
		||||
        向类中添加图片,将其转化为以图生图模式
 | 
			
		||||
        也可用于修改类中已经存在的图片
 | 
			
		||||
        """
 | 
			
		||||
        # 根据图片重写长宽
 | 
			
		||||
        tmpfile = BytesIO(image)
 | 
			
		||||
        image_ = Image.open(tmpfile)
 | 
			
		||||
        width, height = image_.size
 | 
			
		||||
        self.width, self.height = self.shape_set(width, height)
 | 
			
		||||
        self.image = str(base64.b64encode(image), "utf-8")
 | 
			
		||||
        self.steps = 50
 | 
			
		||||
        self.img2img = True
 | 
			
		||||
        self.update_cost()
 | 
			
		||||
 | 
			
		||||
    def shape_set(self, width: int, height: int):
 | 
			
		||||
        """
 | 
			
		||||
        设置宽高
 | 
			
		||||
        """
 | 
			
		||||
        limit = 1024 if config.paid else 640
 | 
			
		||||
        if width * height > pow(min(config.novelai_size, limit), 2):
 | 
			
		||||
            if width <= height:
 | 
			
		||||
                ratio = height / width
 | 
			
		||||
                width: float = config.novelai_size / pow(ratio, 0.5)
 | 
			
		||||
                height: float = width * ratio
 | 
			
		||||
            else:
 | 
			
		||||
                ratio = width / height
 | 
			
		||||
                height: float = config.novelai_size / pow(ratio, 0.5)
 | 
			
		||||
                width: float = height * ratio
 | 
			
		||||
        base = round(max(width, height) / 64)
 | 
			
		||||
        if base > self.max_resolution:
 | 
			
		||||
            base = self.max_resolution
 | 
			
		||||
        if width <= height:
 | 
			
		||||
            return (round(width / height * base) * 64, 64 * base)
 | 
			
		||||
        else:
 | 
			
		||||
            return (64 * base, round(height / width * base) * 64)
 | 
			
		||||
 | 
			
		||||
    async def post_(self, header: dict, post_api: str, json: dict):
 | 
			
		||||
        """
 | 
			
		||||
        向服务器发送请求的核心函数,不要直接调用,请使用post函数
 | 
			
		||||
        :header: 请求头
 | 
			
		||||
        :post_api: 请求地址
 | 
			
		||||
        :json: 请求体
 | 
			
		||||
        """
 | 
			
		||||
        # 请求交互
 | 
			
		||||
        async with aiohttp.ClientSession(headers=header) as session:
 | 
			
		||||
            # 向服务器发送请求
 | 
			
		||||
            async with session.post(post_api, json=json) as resp:
 | 
			
		||||
                if resp.status not in [200, 201]:
 | 
			
		||||
                    logger.error(await resp.text())
 | 
			
		||||
                    raise RuntimeError(f"与服务器沟通时发生{resp.status}错误")
 | 
			
		||||
                img = await self.fromresp(resp)
 | 
			
		||||
                logger.debug(f"获取到返回图片,正在处理")
 | 
			
		||||
 | 
			
		||||
                # 将图片转化为jpg
 | 
			
		||||
                if config.novelai_save == 1:
 | 
			
		||||
                    image_new = await png2jpg(img)
 | 
			
		||||
                else:
 | 
			
		||||
                    image_new = base64.b64decode(img)
 | 
			
		||||
        self.result.append(image_new)
 | 
			
		||||
        return image_new
 | 
			
		||||
 | 
			
		||||
    async def fromresp(self, resp):
 | 
			
		||||
        """
 | 
			
		||||
        处理请求的返回内容,不要直接调用,请使用post函数
 | 
			
		||||
        """
 | 
			
		||||
        img: str = await resp.text()
 | 
			
		||||
        return img.split("data:")[1]
 | 
			
		||||
 | 
			
		||||
    def run(self):
 | 
			
		||||
        """
 | 
			
		||||
        运行核心函数,发送请求并处理
 | 
			
		||||
        """
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def keys(self):
 | 
			
		||||
        return (
 | 
			
		||||
            "seed",
 | 
			
		||||
            "scale",
 | 
			
		||||
            "strength",
 | 
			
		||||
            "noise",
 | 
			
		||||
            "sampler",
 | 
			
		||||
            "model",
 | 
			
		||||
            "steps",
 | 
			
		||||
            "width",
 | 
			
		||||
            "height",
 | 
			
		||||
            "img2img",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, item):
 | 
			
		||||
        return getattr(self, item)
 | 
			
		||||
 | 
			
		||||
    def format(self):
 | 
			
		||||
        dict_self = dict(self)
 | 
			
		||||
        list = []
 | 
			
		||||
        str = ""
 | 
			
		||||
        for i, v in dict_self.items():
 | 
			
		||||
            str += f"{i}={v}\n"
 | 
			
		||||
        list.append(str)
 | 
			
		||||
        list.append(f"tags={self.tags}\n")
 | 
			
		||||
        list.append(f"ntags={self.ntags}")
 | 
			
		||||
        return list
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return (
 | 
			
		||||
            f"time={self.time}\nuser_id={self.user_id}\ngroup_id={self.group_id}\ncost={self.cost}\nbatch={self.batch}\n"
 | 
			
		||||
            + "".join(self.format())
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return self.__repr__().replace("\n", ";")
 | 
			
		||||
							
								
								
									
										34
									
								
								backend/naifu.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								backend/naifu.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,34 @@
 | 
			
		||||
from .base import AIDRAW_BASE
 | 
			
		||||
from ..config import config
 | 
			
		||||
class AIDRAW(AIDRAW_BASE):
 | 
			
		||||
    """队列中的单个请求"""
 | 
			
		||||
 | 
			
		||||
    async def post(self):
 | 
			
		||||
        header = {
 | 
			
		||||
            "content-type": "application/json",
 | 
			
		||||
            "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
 | 
			
		||||
        }
 | 
			
		||||
        site=config.novelai_site or "127.0.0.1:6969"
 | 
			
		||||
        post_api="http://"+site + "/generate-stream"
 | 
			
		||||
        for i in range(self.batch):
 | 
			
		||||
            parameters = {
 | 
			
		||||
                "prompt":self.tags,
 | 
			
		||||
                "width": self.width,
 | 
			
		||||
                "height": self.height,
 | 
			
		||||
                "qualityToggle": False,
 | 
			
		||||
                "scale": self.scale,
 | 
			
		||||
                "sampler": self.sampler,
 | 
			
		||||
                "steps": self.steps,
 | 
			
		||||
                "seed": self.seed[i],
 | 
			
		||||
                "n_samples": 1,
 | 
			
		||||
                "ucPreset": 0,
 | 
			
		||||
                "uc": self.ntags,
 | 
			
		||||
            }
 | 
			
		||||
            if self.img2img:
 | 
			
		||||
                parameters.update({
 | 
			
		||||
                    "image": self.image,
 | 
			
		||||
                    "strength": self.strength,
 | 
			
		||||
                    "noise": self.noise
 | 
			
		||||
                })
 | 
			
		||||
            await self.post_(header, post_api,parameters)
 | 
			
		||||
        return self.result
 | 
			
		||||
							
								
								
									
										44
									
								
								backend/novelai.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								backend/novelai.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
			
		||||
from ..config import config
 | 
			
		||||
from .base import AIDRAW_BASE
 | 
			
		||||
 | 
			
		||||
class AIDRAW(AIDRAW_BASE):
 | 
			
		||||
    """队列中的单个请求"""
 | 
			
		||||
    model: str = "nai-diffusion" if config.novelai_h else "safe-diffusion"
 | 
			
		||||
 | 
			
		||||
    async def post(self):
 | 
			
		||||
        # 获取请求体
 | 
			
		||||
        header = {
 | 
			
		||||
            "authorization": "Bearer " + config.novelai_token,
 | 
			
		||||
            ":authority": "https://api.novelai.net",
 | 
			
		||||
            ":path": "/ai/generate-image",
 | 
			
		||||
            "content-type": "application/json",
 | 
			
		||||
            "referer": "https://novelai.net",
 | 
			
		||||
            "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
 | 
			
		||||
        }
 | 
			
		||||
        post_api = "https://api.novelai.net/ai/generate-image"
 | 
			
		||||
        for i in range(self.batch):
 | 
			
		||||
            parameters = {
 | 
			
		||||
                "width": self.width,
 | 
			
		||||
                "height": self.height,
 | 
			
		||||
                "qualityToggle": False,
 | 
			
		||||
                "scale": self.scale,
 | 
			
		||||
                "sampler": self.sampler,
 | 
			
		||||
                "steps": self.steps,
 | 
			
		||||
                "seed": self.seed[i],
 | 
			
		||||
                "n_samples": 1,
 | 
			
		||||
                "ucPreset": 0,
 | 
			
		||||
                "uc": self.ntags,
 | 
			
		||||
            }
 | 
			
		||||
            if self.img2img:
 | 
			
		||||
                parameters.update({
 | 
			
		||||
                    "image": self.image,
 | 
			
		||||
                    "strength": self.strength,
 | 
			
		||||
                    "noise": self.noise
 | 
			
		||||
                })
 | 
			
		||||
            json= {
 | 
			
		||||
                "input": self.tags,
 | 
			
		||||
                "model": self.model,
 | 
			
		||||
                "parameters": parameters
 | 
			
		||||
            }
 | 
			
		||||
            await self.post_(header, post_api,json)
 | 
			
		||||
        return self.result
 | 
			
		||||
							
								
								
									
										43
									
								
								backend/sd.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								backend/sd.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,43 @@
 | 
			
		||||
from .base import AIDRAW_BASE
 | 
			
		||||
from ..config import config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AIDRAW(AIDRAW_BASE):
 | 
			
		||||
    """队列中的单个请求"""
 | 
			
		||||
    max_resolution: int = 32
 | 
			
		||||
 | 
			
		||||
    async def fromresp(self, resp):
 | 
			
		||||
        img: dict = await resp.json()
 | 
			
		||||
        return img["images"][0]
 | 
			
		||||
 | 
			
		||||
    async def post(self):
 | 
			
		||||
        site=config.novelai_site or "127.0.0.1:7860"
 | 
			
		||||
        header = {
 | 
			
		||||
            "content-type": "application/json",
 | 
			
		||||
            "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
 | 
			
		||||
        }
 | 
			
		||||
        post_api = f"http://{site}/sdapi/v1/img2img" if self.img2img else f"http://{site}/sdapi/v1/txt2img"
 | 
			
		||||
        for i in range(self.batch):
 | 
			
		||||
            parameters = {
 | 
			
		||||
                "prompt": self.tags,
 | 
			
		||||
                "seed": self.seed[i],
 | 
			
		||||
                "steps": self.steps,
 | 
			
		||||
                "cfg_scale": self.scale,
 | 
			
		||||
                "width": self.width,
 | 
			
		||||
                "height": self.height,
 | 
			
		||||
                "sampler_name": self.sampler,
 | 
			
		||||
                "negative_prompt": self.ntags,
 | 
			
		||||
                "enable_hr": self.hires,
 | 
			
		||||
                "denoising_strength": 0.7,
 | 
			
		||||
                "hr_scale": 2,
 | 
			
		||||
                "hr_upscaler": "Latent"
 | 
			
		||||
            }
 | 
			
		||||
            print("向API发送以下参数:")
 | 
			
		||||
            print(parameters)
 | 
			
		||||
            if self.img2img:
 | 
			
		||||
                parameters.update({
 | 
			
		||||
                    "init_images": ["data:image/jpeg;base64,"+self.image],
 | 
			
		||||
                    "denoising_strength": self.strength,
 | 
			
		||||
                })
 | 
			
		||||
            await self.post_(header, post_api, parameters)
 | 
			
		||||
        return self.result
 | 
			
		||||
		Reference in New Issue
	
	Block a user