# Python source file: 254 lines, 9.9 KiB
import time
|
||
import re
|
||
|
||
from collections import deque
|
||
import aiohttp
|
||
from aiohttp.client_exceptions import ClientConnectorError, ClientOSError
|
||
from argparse import Namespace
|
||
from asyncio import get_running_loop
|
||
from nonebot import get_bot, on_shell_command
|
||
|
||
from nonebot.adapters.onebot.v11 import GroupMessageEvent, MessageSegment, Bot
|
||
from nonebot.rule import ArgumentParser
|
||
from nonebot.permission import SUPERUSER
|
||
from nonebot.log import logger
|
||
from nonebot.params import ShellCommandArgs
|
||
|
||
from .config import config
|
||
from .utils.data import lowQuality, basetag, htags
|
||
from .backend import AIDRAW
|
||
from .extension.anlas import anlas_check, anlas_set
|
||
from .extension.daylimit import DayLimit
|
||
from .utils.save import save_img
|
||
from .utils.prepocess import prepocess_tags, combine_multi_args
|
||
from .version import version
|
||
from .utils import sendtosuperuser
|
||
# time.time() stamp of each user's last accepted request, keyed by the user id
# string; read and written by the cooldown check in aidraw_get.
cd = dict()
# True while fifo_gennerate is draining wait_list.
gennerating = False
# Requests waiting to be generated; appended on the right, popped from the left.
wait_list = deque()
|
||
|
||
def _build_aidraw_parser():
    """Create the argument parser for the drawing command.

    The parser takes free-form positional ``tags`` plus the optional switches
    listed in ``option_specs``. Every switch has a short/long ASCII spelling
    and/or a Chinese spelling, and stores its value under the given ``dest``.
    """
    parser = ArgumentParser()
    parser.add_argument("tags", nargs="*", help="标签")

    # (option strings, add_argument keyword arguments), in registration order.
    option_specs = (
        (("-r", "--resolution", "-形状"),
         dict(help="画布形状/分辨率", dest="shape")),
        (("-c", "--scale", "-服从"),
         dict(type=float, help="对输入的服从度", dest="scale")),
        (("-s", "--seed", "-种子"),
         dict(type=int, help="种子", dest="seed")),
        (("-b", "--batch", "-数量"),
         dict(type=int, default=1, help="生成数量", dest="batch")),
        (("-t", "--steps", "-步数"),
         dict(type=int, help="步数", dest="steps")),
        (("-u", "--ntags", "-排除"),
         dict(default=" ", nargs="*", help="负面标签", dest="ntags")),
        (("-e", "--strength", "-强度"),
         dict(type=float, help="修改强度", dest="strength")),
        (("-n", "--noise", "-噪声"),
         dict(type=float, help="修改噪声", dest="noise")),
        (("-o", "--override", "-不优化"),
         dict(action="store_true", help="不使用内置优化参数", dest="override")),
        (("--sampler", "-采样器"),
         dict(default="Euler a", nargs="+", help="设置采样器", dest="sampler")),
        (("--hires", "-高清修复"),
         dict(action="store_true", help="启用高清修复", dest="hires")),
    )
    for option_strings, settings in option_specs:
        parser.add_argument(*option_strings, **settings)
    return parser


aidraw_parser = _build_aidraw_parser()
|
||
|
||
# Matcher for the drawing command: triggered by ".aidraw" or any of the
# aliases, with the command line parsed by aidraw_parser.
aidraw = on_shell_command(
    ".aidraw",
    parser=aidraw_parser,
    aliases={"aidraw", "绘画", "咏唱", "召唤", "约稿"},
    priority=5,
)
|
||
|
||
|
||
@aidraw.handle()
async def aidraw_get(bot: Bot, event: GroupMessageEvent, args: Namespace = ShellCommandArgs()):
    """Validate a drawing command from a group chat and queue it for generation.

    Steps, in order: skip if the feature is off for the group, clamp the batch
    size, enforce the per-user cooldown and the daily quota, build the AIDRAW
    request, reject blocked tags, attach an input image if one was sent or
    replied to, check the point balance, then hand the request to wait_fifo.
    Every rejection is reported to the user through ``aidraw.finish``.

    Args:
        bot: bot that received the event.
        event: the triggering group message.
        args: namespace produced by the command's argument parser.
    """
    user_id = str(event.user_id)
    group_id = str(event.group_id)
    # Nothing to do when the feature is switched off for this group.
    if not await config.get_value(group_id, "on"):
        return

    message = ""
    # Clamp the batch size to the configured maximum.
    if args.batch > config.novelai_max:
        message += f",批量生成数量过多,自动修改为{config.novelai_max}"
        args.batch = config.novelai_max

    # Cooldown is checked before the daily quota so that a request rejected
    # for being too fast never reaches DayLimit.count.
    # NOTE(review): DayLimit.count presumably consumes quota - confirm.
    nowtime = time.time()
    deltatime = nowtime - cd.get(user_id, 0)
    cd_ = int(await config.get_value(group_id, "cd"))
    if deltatime < cd_:
        await aidraw.finish(f"你冲的太快啦,请休息一下吧,剩余CD为{cd_ - int(deltatime)}s")

    # Daily quota; superusers are exempt.
    if config.novelai_daylimit and not await SUPERUSER(bot, event):
        left = DayLimit.count(user_id, args.batch)
        if left == -1:
            await aidraw.finish("今天你的次数不够了哦")
        else:
            message += f",今天你还能够生成{left}张"

    # The request passed both limits: start this user's cooldown now.
    cd[user_id] = nowtime

    # Normalise the parsed arguments and build the request object.
    args.tags = await prepocess_tags(args.tags)
    args.ntags = await prepocess_tags(args.ntags)
    args.sampler = await combine_multi_args(args.sampler)
    fifo = AIDRAW(user_id=user_id, group_id=group_id, **vars(args))

    # Reject prompts containing a blocked tag unless H content is enabled.
    if not config.novelai_h:
        # Raw f-string: "\s" must reach the regex engine as written instead of
        # being treated as an (invalid) string escape sequence.
        pattern = re.compile(rf"(\s|,|^)({htags})(\s|,|$)")
        if pattern.search(fifo.tags) is not None:
            await aidraw.finish("H是不行的!")

    # Prepend the built-in and per-group tags unless the user opted out.
    if not args.override:
        fifo.tags = basetag + await config.get_value(group_id, "tags") + "," + fifo.tags
        fifo.ntags = lowQuality + fifo.ntags

    # Image-to-image: the last image in the replied-to message is used, and an
    # image in the command message itself takes precedence over it.
    img_url = ""
    reply = event.reply
    if reply:
        for seg in reply.message["image"]:
            img_url = seg.data["url"]
    for seg in event.message["image"]:
        img_url = seg.data["url"]
    if img_url:
        if not config.novelai_paid:
            await aidraw.finish("以图生图功能已禁用")
        async with aiohttp.ClientSession() as session:
            logger.info("检测到图片,自动切换到以图生图,正在获取图片")
            async with session.get(img_url) as resp:
                fifo.add_image(await resp.read())
        message = ",已切换至以图生图" + message
    logger.debug(fifo)

    # Queue the request, checking the point balance first when it costs points.
    if fifo.cost > 0:
        anlascost = fifo.cost
        hasanlas = await anlas_check(fifo.user_id)
        if hasanlas < anlascost:
            await aidraw.finish(f"你的点数不足,你的剩余点数为{hasanlas}")
        else:
            await wait_fifo(fifo, anlascost, hasanlas - anlascost, message=message, bot=bot)
    else:
        await wait_fifo(fifo, message=message, bot=bot)
|
||
|
||
|
||
async def wait_fifo(fifo, anlascost=None, anlas=None, message="", bot=None):
    """Acknowledge a request to the user, then pass it on for generation.

    Args:
        fifo: the request to generate.
        anlascost: points this request costs, or None when nothing is charged.
        anlas: points left after paying ``anlascost``, or None when nothing is
            charged.
        message: extra notice text appended to the status line.
        bot: bot instance forwarded to fifo_gennerate.
    """
    list_len = wait_len()
    # The "queued behind N people" wording is only used in queue mode and only
    # when somebody is actually ahead of this request.
    if config.novelai_limit and list_len > 0:
        notice = f"排队中,你的前面还有{list_len}人" + message
    else:
        notice = "请稍等,图片生成中" + message
    # Compare with None instead of truthiness: a remaining balance of exactly
    # 0 is a real value and the cost line must still be shown for it.
    if anlas is not None:
        notice += f"\n本次生成消耗点数{anlascost},你的剩余点数为{anlas}"
    await aidraw.send(notice)

    if config.novelai_limit:
        # Queue mode: enqueue and let fifo_gennerate drain the queue.
        wait_list.append(fifo)
        await fifo_gennerate(bot=bot)
    else:
        # No queue: generate this request directly.
        await fifo_gennerate(fifo, bot)
|
||
|
||
|
||
def wait_len():
    """Return the number of requests ahead of a newly submitted one.

    That is the length of ``wait_list`` plus one more when the ``gennerating``
    flag is set.
    """
    in_progress = 1 if gennerating else 0
    return len(wait_list) + in_progress
|
||
|
||
|
||
async def fifo_gennerate(fifo: AIDRAW = None, bot: Bot = None):
    """Generate a request directly and/or drain the waiting queue.

    When ``fifo`` is given it is generated immediately. After that, unless
    another call is already draining the queue (``gennerating`` is set), this
    call marks itself as the consumer, processes ``wait_list`` until it is
    empty, clears the flag and finally runs the version update check.

    Args:
        fifo: request to generate right away, or None to only drain the queue.
        bot: bot used to send results; defaults to ``get_bot()``.
    """
    global gennerating
    if not bot:
        bot = get_bot()

    async def generate(task: AIDRAW):
        """Generate one request and post the images (or the error) to its group."""
        # Forward-message nodes are attributed to the requester when
        # novelai_antireport is on, otherwise to the bot itself.
        sender_id = task.user_id if config.novelai_antireport else bot.self_id
        resp = await bot.get_group_member_info(group_id=task.group_id, user_id=task.user_id)
        nickname = resp["card"] or resp["nickname"]

        logger.info(f"队列剩余{wait_len()}人 | 开始生成:{task}")
        try:
            im = await _run_gennerate(task)
        except Exception as e:
            # Report the failure in the group instead of failing silently.
            logger.exception("生成失败")
            message = "生成失败," + "".join(str(arg) for arg in e.args)
            await bot.send_group_msg(
                message=message,
                group_id=task.group_id,
            )
            return

        logger.info(f"队列剩余{wait_len()}人 | 生成完毕:{task}")
        if await config.get_value(task.group_id, "pure"):
            # "pure" mode: mention the requester and send only the images.
            message = MessageSegment.at(task.user_id)
            for image_seg in im["image"]:
                message += image_seg
            message_data = await bot.send_group_msg(
                message=message,
                group_id=task.group_id,
            )
        else:
            # Otherwise wrap every segment in a node of a merged forward message.
            message = [
                MessageSegment.node_custom(sender_id, nickname, seg)
                for seg in im
            ]
            message_data = await bot.send_group_forward_msg(
                messages=message,
                group_id=task.group_id,
            )

        # Optionally recall the result after the configured delay.
        revoke = await config.get_value(task.group_id, "revoke")
        if revoke:
            message_id = message_data["message_id"]
            loop = get_running_loop()
            loop.call_later(
                revoke,
                lambda: loop.create_task(
                    bot.delete_msg(message_id=message_id)),
            )

    if fifo:
        await generate(fifo)

    if not gennerating:
        logger.info("队列开始")
        gennerating = True
        try:
            while wait_list:
                queued = wait_list.popleft()
                try:
                    await generate(queued)
                except Exception:
                    # One broken request must not stop the queue, but the
                    # failure is logged instead of being discarded. Using
                    # Exception (not a bare except) lets task cancellation
                    # and interpreter shutdown propagate.
                    logger.exception("队列任务处理失败,已跳过")
        finally:
            # Always release the consumer flag, otherwise an interrupted
            # drain would leave the queue permanently marked as busy.
            gennerating = False
        logger.info("队列结束")
        await version.check_update()
|
||
|
||
|
||
async def _run_gennerate(fifo: AIDRAW):
    """Run a single request and build the reply message for it.

    Posts the request through ``fifo.post()``, saves every resulting image,
    assembles a message from the completion text, the images and the lines of
    ``fifo.format()``, deducts the request's point cost and returns the
    message.

    Raises:
        RuntimeError: when the backend refuses the connection or the
            connection breaks; superusers are notified before it is raised.
    """
    try:
        await fifo.post()
    except ClientConnectorError:
        refused = "远程服务器拒绝连接,请检查配置是否正确,服务器是否已经启动"
        await sendtosuperuser(refused)
        raise RuntimeError(refused)
    except ClientOSError:
        await sendtosuperuser("远程服务器崩掉了欸……")
        raise RuntimeError("服务器崩掉了欸……请等待主人修复吧")

    # To enable the AI safety check, uncomment the next line and comment out
    # the message-building section below.
    # message = await check_safe_method(fifo, img_bytes, message)

    # Build the reply and save each image as it is added.
    message = f"{config.novelai_mode}绘画完成~"
    for image in fifo.result:
        await save_img(fifo, image, fifo.group_id)
        message = message + MessageSegment.image(image)
    for text in fifo.format():
        message = message + MessageSegment.text(text)

    # Charge the points only after generation succeeded.
    cost = fifo.cost
    if cost > 0:
        await anlas_set(fifo.user_id, -cost)
    return message
|