新增指定sampler, hires的功能

修复tags中方括号编码不正常的问题
This commit is contained in:
Eigeen 2023-03-01 21:14:18 +08:00
commit deae8e4888
29 changed files with 1503 additions and 0 deletions

138
.gitignore vendored Normal file
View File

@ -0,0 +1,138 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/

10
__init__.py Normal file
View File

@ -0,0 +1,10 @@
from . import config, manage
from .aidraw import AIDRAW
from nonebot.plugin import PluginMetadata
from .extension.deepdanbooru import deepdanbooru
__plugin_meta__ = PluginMetadata(
name="AI绘图",
description="调用novelai进行二次元AI绘图",
usage=f"基础用法:\n.aidraw[指令] [空格] loli,[参数]\n示例:.aidraw loli,cute,kawaii,\n项目地址:https://github.com/Mutsukibot/tree/nonebot-plugin-novelai\n说明书https://sena-nana.github.io/MutsukiDocs/",
)
__all__ = ["AIDRAW", "__plugin_meta__"]

253
aidraw.py Normal file
View File

@ -0,0 +1,253 @@
import time
import re
from collections import deque
import aiohttp
from aiohttp.client_exceptions import ClientConnectorError, ClientOSError
from argparse import Namespace
from asyncio import get_running_loop
from nonebot import get_bot, on_shell_command
from nonebot.adapters.onebot.v11 import GroupMessageEvent, MessageSegment, Bot
from nonebot.rule import ArgumentParser
from nonebot.permission import SUPERUSER
from nonebot.log import logger
from nonebot.params import ShellCommandArgs
from .config import config
from .utils.data import lowQuality, basetag, htags
from .backend import AIDRAW
from .extension.anlas import anlas_check, anlas_set
from .extension.daylimit import DayLimit
from .utils.save import save_img
from .utils.prepocess import prepocess_tags, combine_multi_args
from .version import version
from .utils import sendtosuperuser
cd = {}
gennerating = False
wait_list = deque([])
aidraw_parser = ArgumentParser()
aidraw_parser.add_argument("tags", nargs="*", help="标签")
aidraw_parser.add_argument("-r", "--resolution", "-形状",
help="画布形状/分辨率", dest="shape")
aidraw_parser.add_argument("-c", "--scale", "-服从",
type=float, help="对输入的服从度", dest="scale")
aidraw_parser.add_argument(
"-s", "--seed", "-种子", type=int, help="种子", dest="seed")
aidraw_parser.add_argument("-b", "--batch", "-数量",
type=int, default=1, help="生成数量", dest="batch")
aidraw_parser.add_argument("-t", "--steps", "-步数",
type=int, help="步数", dest="steps")
aidraw_parser.add_argument("-u", "--ntags", "-排除",
default=" ", nargs="*", help="负面标签", dest="ntags")
aidraw_parser.add_argument("-e", "--strength", "-强度",
type=float, help="修改强度", dest="strength")
aidraw_parser.add_argument("-n", "--noise", "-噪声",
type=float, help="修改噪声", dest="noise")
aidraw_parser.add_argument("-o", "--override", "-不优化",
action='store_true', help="不使用内置优化参数", dest="override")
aidraw_parser.add_argument("--sampler", "-采样器",
default="Euler a", nargs="+", help="设置采样器", dest="sampler")
aidraw_parser.add_argument("--hires", "-高清修复",
action='store_true', help="启用高清修复", dest="hires")
aidraw = on_shell_command(
".aidraw",
aliases={"绘画", "咏唱", "召唤", "约稿", "aidraw"},
parser=aidraw_parser,
priority=5
)
@aidraw.handle()
async def aidraw_get(bot: Bot, event: GroupMessageEvent, args: Namespace = ShellCommandArgs()):
user_id = str(event.user_id)
group_id = str(event.group_id)
# 判断是否禁用,若没禁用,进入处理流程
if await config.get_value(group_id, "on"):
message = ""
# 判断最大生成数量
if args.batch > config.novelai_max:
message = message+f",批量生成数量过多,自动修改为{config.novelai_max}"
args.batch = config.novelai_max
# 判断次数限制
if config.novelai_daylimit and not await SUPERUSER(bot, event):
left = DayLimit.count(user_id, args.batch)
if left == -1:
await aidraw.finish(f"今天你的次数不够了哦")
else:
message = message + f",今天你还能够生成{left}"
# 判断cd
nowtime = time.time()
deltatime = nowtime - cd.get(user_id, 0)
cd_ = int(await config.get_value(group_id, "cd"))
if deltatime < cd_:
await aidraw.finish(f"你冲的太快啦请休息一下吧剩余CD为{cd_ - int(deltatime)}s")
else:
cd[user_id] = nowtime
# 初始化参数
args.tags = await prepocess_tags(args.tags)
args.ntags = await prepocess_tags(args.ntags)
args.sampler = await combine_multi_args(args.sampler)
fifo = AIDRAW(user_id=user_id, group_id=group_id, **vars(args))
# 检测是否有18+词条
if not config.novelai_h:
pattern = re.compile(f"(\s|,|^)({htags})(\s|,|$)")
if (re.search(pattern, fifo.tags) is not None):
await aidraw.finish(f"H是不行的!")
if not args.override:
fifo.tags = basetag + await config.get_value(group_id, "tags") + "," + fifo.tags
fifo.ntags = lowQuality + fifo.ntags
# 以图生图预处理
img_url = ""
reply = event.reply
if reply:
for seg in reply.message['image']:
img_url = seg.data["url"]
for seg in event.message['image']:
img_url = seg.data["url"]
if img_url:
if config.novelai_paid:
async with aiohttp.ClientSession() as session:
logger.info(f"检测到图片,自动切换到以图生图,正在获取图片")
async with session.get(img_url) as resp:
fifo.add_image(await resp.read())
message = f",已切换至以图生图"+message
else:
await aidraw.finish(f"以图生图功能已禁用")
logger.debug(fifo)
# 初始化队列
if fifo.cost > 0:
anlascost = fifo.cost
hasanlas = await anlas_check(fifo.user_id)
if hasanlas >= anlascost:
await wait_fifo(fifo, anlascost, hasanlas - anlascost, message=message, bot=bot)
else:
await aidraw.finish(f"你的点数不足,你的剩余点数为{hasanlas}")
else:
await wait_fifo(fifo, message=message, bot=bot)
async def wait_fifo(fifo, anlascost=None, anlas=None, message="", bot=None):
# 创建队列
list_len = wait_len()
has_wait = f"排队中,你的前面还有{list_len}"+message
no_wait = "请稍等,图片生成中"+message
if anlas:
has_wait += f"\n本次生成消耗点数{anlascost},你的剩余点数为{anlas}"
no_wait += f"\n本次生成消耗点数{anlascost},你的剩余点数为{anlas}"
if config.novelai_limit:
await aidraw.send(has_wait if list_len > 0 else no_wait)
wait_list.append(fifo)
await fifo_gennerate(bot=bot)
else:
await aidraw.send(no_wait)
await fifo_gennerate(fifo, bot)
def wait_len():
# 获取剩余队列长度
list_len = len(wait_list)
if gennerating:
list_len += 1
return list_len
async def fifo_gennerate(fifo: AIDRAW = None, bot: Bot = None):
# 队列处理
global gennerating
if not bot:
bot = get_bot()
async def generate(fifo: AIDRAW):
id = fifo.user_id if config.novelai_antireport else bot.self_id
resp = await bot.get_group_member_info(group_id=fifo.group_id, user_id=fifo.user_id)
nickname = resp["card"] or resp["nickname"]
# 开始生成
logger.info(
f"队列剩余{wait_len()}人 | 开始生成:{fifo}")
try:
im = await _run_gennerate(fifo)
except Exception as e:
logger.exception("生成失败")
message = f"生成失败,"
for i in e.args:
message += str(i)
await bot.send_group_msg(
message=message,
group_id=fifo.group_id
)
else:
logger.info(f"队列剩余{wait_len()}人 | 生成完毕:{fifo}")
if await config.get_value(fifo.group_id, "pure"):
message = MessageSegment.at(fifo.user_id)
for i in im["image"]:
message += i
message_data = await bot.send_group_msg(
message=message,
group_id=fifo.group_id,
)
else:
message = []
for i in im:
message.append(MessageSegment.node_custom(
id, nickname, i))
message_data = await bot.send_group_forward_msg(
messages=message,
group_id=fifo.group_id,
)
revoke = await config.get_value(fifo.group_id, "revoke")
if revoke:
message_id = message_data["message_id"]
loop = get_running_loop()
loop.call_later(
revoke,
lambda: loop.create_task(
bot.delete_msg(message_id=message_id)),
)
if fifo:
await generate(fifo)
if not gennerating:
logger.info("队列开始")
gennerating = True
while len(wait_list) > 0:
fifo = wait_list.popleft()
try:
await generate(fifo)
except:
pass
gennerating = False
logger.info("队列结束")
await version.check_update()
async def _run_gennerate(fifo: AIDRAW):
# 处理单个请求
try:
await fifo.post()
except ClientConnectorError:
await sendtosuperuser(f"远程服务器拒绝连接,请检查配置是否正确,服务器是否已经启动")
raise RuntimeError(f"远程服务器拒绝连接,请检查配置是否正确,服务器是否已经启动")
except ClientOSError:
await sendtosuperuser(f"远程服务器崩掉了欸……")
raise RuntimeError(f"服务器崩掉了欸……请等待主人修复吧")
# 若启用ai检定取消注释下行代码并将构造消息体部分注释
# message = await check_safe_method(fifo, img_bytes, message)
# 构造消息体并保存图片
message = f"{config.novelai_mode}绘画完成~"
for i in fifo.result:
await save_img(fifo, i, fifo.group_id)
message += MessageSegment.image(i)
for i in fifo.format():
message += MessageSegment.text(i)
# 扣除点数
if fifo.cost > 0:
await anlas_set(fifo.user_id, -fifo.cost)
return message

0
amusement/ramdomgirl.py Normal file
View File

0
amusement/wordbank.py Normal file
View File

20
backend/__init__.py Normal file
View File

@ -0,0 +1,20 @@
from ..config import config
"""def AIDRAW():
if config.novelai_mode=="novelai":
from .novelai import AIDRAW
elif config.novelai_mode=="naifu":
from .naifu import AIDRAW
elif config.novelai_mode=="sd":
from .sd import AIDRAW
else:
raise RuntimeError(f"错误的mode设置支持的字符串为'novelai','naifu','sd'")
return AIDRAW()"""
if config.novelai_mode=="novelai":
from .novelai import AIDRAW
elif config.novelai_mode=="naifu":
from .naifu import AIDRAW
elif config.novelai_mode=="sd":
from .sd import AIDRAW
else:
raise RuntimeError(f"错误的mode设置支持的字符串为'novelai','naifu','sd'")

266
backend/base.py Normal file
View File

@ -0,0 +1,266 @@
import asyncio
import base64
import random
import time
from io import BytesIO
import aiohttp
from nonebot import get_driver
from nonebot.log import logger
from PIL import Image
from ..config import config
from ..utils import png2jpg
from ..utils.data import shapemap
class AIDRAW_BASE:
max_resolution: int = 16
sampler: str
def __init__(
self,
user_id: str,
group_id: str,
tags: str = "",
ntags: str = "",
seed: int = None,
scale: int = None,
steps: int = None,
batch: int = None,
strength: float = None,
noise: float = None,
shape: str = "p",
model: str = None,
sampler: str = "",
hires: bool = False,
**kwargs,
):
"""
AI绘画的核心部分,将与服务器通信的过程包装起来,并方便扩展服务器类型
:user_id: 用户id,必须
:group_id: 群聊id,如果是私聊则应置为0,必须
:tags: 图像的标签
:ntags: 图像的反面标签
:seed: 生成的种子不指定的情况下随机生成
:scale: 标签的参考度值越高越贴近于标签,但可能产生过度锐化范围为0-30,默认11
:steps: 训练步数范围为1-50,默认28.以图生图时强制50
:batch: 同时生成数量
:strength: 以图生图时使用,变化的强度范围为0-1,默认0.7
:noise: 以图生图时使用,变化的噪音,数值越大细节越多,但可能产生伪影,不建议超过strength范围0-1,默认0.2
:shape: 图像的形状支持"p""s""l"三种同时支持以"x"分割宽高的指定分辨率
该值会被设置限制并不会严格遵守输入
类初始化后,该参数会被拆分为:width::height:
:model: 指定的模型模型名称在配置文件中手动设置不指定模型则按照负载均衡自动选择
AIDRAW还包含了以下几种内置的参数
:status: 记录了AIDRAW的状态,默认为0等待中(处理中)
非0的值为运行完成后的状态值,200和201为正常运行,其余值为产生错误
:result: 当正常运行完成后,该参数为一个包含了生成图片bytes信息的数组
:maxresolution: 一般不用管用于限制不同服务器的最大分辨率
如果你的SD经过了魔改支持更大分辨率可以修改该值并重新设置宽高
:cost: 记录了本次生成需要花费多少点数自动计算
:signal: asyncio.Event类,可以作为信号使用仅占位需要自行实现相关方法
"""
self.status: int = 0
self.result: list = []
self.signal: asyncio.Event = None
self.model = model
self.time = time.strftime("%Y-%m-%d %H:%M:%S")
self.user_id: str = user_id
self.tags: str = tags
self.seed: list[int] = [seed or random.randint(0, 4294967295)]
self.group_id: str = group_id
self.scale: int = int(scale or 11)
self.strength: float = strength or 0.7
self.batch: int = batch or 1
self.steps: int = steps or 28
self.noise: float = noise or 0.2
self.ntags: str = ntags
self.img2img: bool = False
self.image: str = None
self.width, self.height = self.extract_shape(shape)
self.sampler = sampler
self.hires = hires
# 数值合法检查
if self.steps <= 0 or self.steps > (50 if not config.novelai_paid else 28):
self.steps = 28
if self.strength < 0 or self.strength > 1:
self.strength = 0.7
if self.noise < 0 or self.noise > 1:
self.noise = 0.2
if self.scale <= 0 or self.scale > 30:
self.scale = 11
# 多图时随机填充剩余seed
for i in range(self.batch - 1):
self.seed.append(random.randint(0, 4294967295))
# 计算cost
self.update_cost()
def extract_shape(self, shape: str):
"""
将shape拆分为width和height
"""
if shape:
if "x" in shape:
width, height, *_ = shape.split("x")
if width.isdigit() and height.isdigit():
return self.shape_set(int(width), int(height))
else:
return shapemap.get(shape)
else:
return shapemap.get(shape)
else:
return (512, 768)
def update_cost(self):
"""
更新cost
"""
if config.novelai_paid == 1:
anlas = 0
if (self.width * self.height > 409600) or self.image or self.batch > 1:
anlas = round(
self.width
* self.height
* self.strength
* self.batch
* self.steps
/ 2293750
)
if anlas < 2:
anlas = 2
if self.user_id in get_driver().config.superusers:
self.cost = 0
else:
self.cost = anlas
elif config.novelai_paid == 2:
anlas = round(
self.width
* self.height
* self.strength
* self.batch
* self.steps
/ 2293750
)
if anlas < 2:
anlas = 2
if self.user_id in get_driver().config.superusers:
self.cost = 0
else:
self.cost = anlas
else:
self.cost = 0
def add_image(self, image: bytes):
"""
向类中添加图片将其转化为以图生图模式
也可用于修改类中已经存在的图片
"""
# 根据图片重写长宽
tmpfile = BytesIO(image)
image_ = Image.open(tmpfile)
width, height = image_.size
self.width, self.height = self.shape_set(width, height)
self.image = str(base64.b64encode(image), "utf-8")
self.steps = 50
self.img2img = True
self.update_cost()
def shape_set(self, width: int, height: int):
"""
设置宽高
"""
limit = 1024 if config.paid else 640
if width * height > pow(min(config.novelai_size, limit), 2):
if width <= height:
ratio = height / width
width: float = config.novelai_size / pow(ratio, 0.5)
height: float = width * ratio
else:
ratio = width / height
height: float = config.novelai_size / pow(ratio, 0.5)
width: float = height * ratio
base = round(max(width, height) / 64)
if base > self.max_resolution:
base = self.max_resolution
if width <= height:
return (round(width / height * base) * 64, 64 * base)
else:
return (64 * base, round(height / width * base) * 64)
async def post_(self, header: dict, post_api: str, json: dict):
"""
向服务器发送请求的核心函数不要直接调用请使用post函数
:header: 请求头
:post_api: 请求地址
:json: 请求体
"""
# 请求交互
async with aiohttp.ClientSession(headers=header) as session:
# 向服务器发送请求
async with session.post(post_api, json=json) as resp:
if resp.status not in [200, 201]:
logger.error(await resp.text())
raise RuntimeError(f"与服务器沟通时发生{resp.status}错误")
img = await self.fromresp(resp)
logger.debug(f"获取到返回图片,正在处理")
# 将图片转化为jpg
if config.novelai_save == 1:
image_new = await png2jpg(img)
else:
image_new = base64.b64decode(img)
self.result.append(image_new)
return image_new
async def fromresp(self, resp):
"""
处理请求的返回内容不要直接调用请使用post函数
"""
img: str = await resp.text()
return img.split("data:")[1]
def run(self):
"""
运行核心函数发送请求并处理
"""
pass
def keys(self):
return (
"seed",
"scale",
"strength",
"noise",
"sampler",
"model",
"steps",
"width",
"height",
"img2img",
)
def __getitem__(self, item):
return getattr(self, item)
def format(self):
dict_self = dict(self)
list = []
str = ""
for i, v in dict_self.items():
str += f"{i}={v}\n"
list.append(str)
list.append(f"tags={self.tags}\n")
list.append(f"ntags={self.ntags}")
return list
def __repr__(self):
return (
f"time={self.time}\nuser_id={self.user_id}\ngroup_id={self.group_id}\ncost={self.cost}\nbatch={self.batch}\n"
+ "".join(self.format())
)
def __str__(self):
return self.__repr__().replace("\n", ";")

34
backend/naifu.py Normal file
View File

@ -0,0 +1,34 @@
from .base import AIDRAW_BASE
from ..config import config
class AIDRAW(AIDRAW_BASE):
"""队列中的单个请求"""
async def post(self):
header = {
"content-type": "application/json",
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
}
site=config.novelai_site or "127.0.0.1:6969"
post_api="http://"+site + "/generate-stream"
for i in range(self.batch):
parameters = {
"prompt":self.tags,
"width": self.width,
"height": self.height,
"qualityToggle": False,
"scale": self.scale,
"sampler": self.sampler,
"steps": self.steps,
"seed": self.seed[i],
"n_samples": 1,
"ucPreset": 0,
"uc": self.ntags,
}
if self.img2img:
parameters.update({
"image": self.image,
"strength": self.strength,
"noise": self.noise
})
await self.post_(header, post_api,parameters)
return self.result

44
backend/novelai.py Normal file
View File

@ -0,0 +1,44 @@
from ..config import config
from .base import AIDRAW_BASE
class AIDRAW(AIDRAW_BASE):
"""队列中的单个请求"""
model: str = "nai-diffusion" if config.novelai_h else "safe-diffusion"
async def post(self):
# 获取请求体
header = {
"authorization": "Bearer " + config.novelai_token,
":authority": "https://api.novelai.net",
":path": "/ai/generate-image",
"content-type": "application/json",
"referer": "https://novelai.net",
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
}
post_api = "https://api.novelai.net/ai/generate-image"
for i in range(self.batch):
parameters = {
"width": self.width,
"height": self.height,
"qualityToggle": False,
"scale": self.scale,
"sampler": self.sampler,
"steps": self.steps,
"seed": self.seed[i],
"n_samples": 1,
"ucPreset": 0,
"uc": self.ntags,
}
if self.img2img:
parameters.update({
"image": self.image,
"strength": self.strength,
"noise": self.noise
})
json= {
"input": self.tags,
"model": self.model,
"parameters": parameters
}
await self.post_(header, post_api,json)
return self.result

43
backend/sd.py Normal file
View File

@ -0,0 +1,43 @@
from .base import AIDRAW_BASE
from ..config import config
class AIDRAW(AIDRAW_BASE):
"""队列中的单个请求"""
max_resolution: int = 32
async def fromresp(self, resp):
img: dict = await resp.json()
return img["images"][0]
async def post(self):
site=config.novelai_site or "127.0.0.1:7860"
header = {
"content-type": "application/json",
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
}
post_api = f"http://{site}/sdapi/v1/img2img" if self.img2img else f"http://{site}/sdapi/v1/txt2img"
for i in range(self.batch):
parameters = {
"prompt": self.tags,
"seed": self.seed[i],
"steps": self.steps,
"cfg_scale": self.scale,
"width": self.width,
"height": self.height,
"sampler_name": self.sampler,
"negative_prompt": self.ntags,
"enable_hr": self.hires,
"denoising_strength": 0.7,
"hr_scale": 2,
"hr_upscaler": "Latent"
}
print("向API发送以下参数")
print(parameters)
if self.img2img:
parameters.update({
"init_images": ["data:image/jpeg;base64,"+self.image],
"denoising_strength": self.strength,
})
await self.post_(header, post_api, parameters)
return self.result

155
config.py Normal file
View File

@ -0,0 +1,155 @@
import json
from pathlib import Path
import aiofiles
from nonebot import get_driver
from nonebot.log import logger
from pydantic import BaseSettings, validator
from pydantic.fields import ModelField
jsonpath = Path("data/novelai/config.json").resolve()
nickname = list(get_driver().config.nickname)[0] if len(
get_driver().config.nickname) else "nonebot-plugin-novelai"
class Config(BaseSettings):
# 服务器设置
novelai_token: str = "" # 官网的token
# novelai: dict = {"novelai":""}# 你的服务器地址包含端口不包含http头例:127.0.0.1:6969
novelai_mode: str = "novelai"
novelai_site: str = ""
# 后台设置
novelai_save: int = 1 # 是否保存图片至本地,0为不保存1保存2同时保存追踪信息
novelai_paid: int = 0 # 0为禁用付费模式1为点数制2为不限制
novelai_pure: bool = False # 是否启用简洁返回模式只返回图片不返回tag等数据
novelai_limit: bool = True # 是否开启限速
novelai_daylimit: int = 0 # 每日次数限制0为禁用
novelai_h: bool = False # 是否允许H
novelai_antireport: bool = True # 玄学选项。开启后合并消息内发送者将会显示为调用指令的人而不是bot
novelai_max: int = 3 # 每次能够生成的最大数量
# 允许生成的图片最大分辨率,对应(值)^2.默认为1024即1024*1024。如果服务器比较寄建议改成640640*640或者根据能够承受的情况修改。naifu和novelai会分别限制最大长宽为1024
novelai_size: int = 1024
# 可运行更改的设置
novelai_tags: str = "" # 内置的tag
novelai_ntags: str = "" # 内置的反tag
novelai_cd: int = 60 # 默认的cd
novelai_on: bool = True # 是否全局开启
novelai_revoke: int = 0 # 是否自动撤回该值不为0时则为撤回时间
# 翻译API设置
bing_key: str = None # bing的翻译key
deepl_key: str = None # deepL的翻译key
# 允许单群设置的设置
def keys(cls):
return ("novelai_cd", "novelai_tags", "novelai_on", "novelai_ntags", "novelai_revoke")
def __getitem__(cls, item):
return getattr(cls, item)
@validator("novelai_cd", "novelai_max")
def non_negative(cls, v: int, field: ModelField):
if v < 1:
return field.default
return v
@validator("novelai_paid")
def paid(cls, v: int, field: ModelField):
if v < 0:
return field.default
elif v > 3:
return field.default
return v
class Config:
extra = "ignore"
async def set_enable(cls, group_id, enable):
# 设置分群启用
await cls.__init_json()
now = await cls.get_value(group_id, "on")
logger.debug(now)
if now:
if enable:
return f"aidraw已经处于启动状态"
else:
if await cls.set_value(group_id, "on", "false"):
return f"aidraw已关闭"
else:
if enable:
if await cls.set_value(group_id, "on", "true"):
return f"aidraw开始运行"
else:
return f"aidraw已经处于关闭状态"
async def __init_json(cls):
# 初始化设置文件
if not jsonpath.exists():
jsonpath.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(jsonpath, "w+") as f:
await f.write("{}")
async def get_value(cls, group_id, arg: str):
# 获取设置值
group_id = str(group_id)
arg_ = arg if arg.startswith("novelai_") else "novelai_" + arg
if arg_ in cls.keys():
await cls.__init_json()
async with aiofiles.open(jsonpath, "r") as f:
jsonraw = await f.read()
configdict: dict = json.loads(jsonraw)
return configdict.get(group_id, {}).get(arg_, dict(cls)[arg_])
else:
return None
async def get_groupconfig(cls, group_id):
# 获取当群所有设置值
group_id = str(group_id)
await cls.__init_json()
async with aiofiles.open(jsonpath, "r") as f:
jsonraw = await f.read()
configdict: dict = json.loads(jsonraw)
baseconfig = {}
for i in cls.keys():
value = configdict.get(group_id, {}).get(
i, dict(cls)[i])
baseconfig[i] = value
logger.debug(baseconfig)
return baseconfig
async def set_value(cls, group_id, arg: str, value: str):
"""设置当群设置值"""
# 将值转化为bool和int
if value.isdigit():
value: int = int(value)
elif value.lower() == "false":
value = False
elif value.lower() == "true":
value = True
group_id = str(group_id)
arg_ = arg if arg.startswith("novelai_") else "novelai_" + arg
# 判断是否合法
if arg_ in cls.keys() and isinstance(value, type(dict(cls)[arg_])):
await cls.__init_json()
# 读取文件
async with aiofiles.open(jsonpath, "r") as f:
jsonraw = await f.read()
configdict: dict = json.loads(jsonraw)
# 设置值
groupdict = configdict.get(group_id, {})
if value == "default":
groupdict[arg_] = False
else:
groupdict[arg_] = value
configdict[group_id] = groupdict
# 写入文件
async with aiofiles.open(jsonpath, "w") as f:
jsonnew = json.dumps(configdict)
await f.write(jsonnew)
return True
else:
logger.debug(f"不正确的赋值,{arg_},{value},{type(value)}")
return False
config = Config(**get_driver().config.dict())
logger.info(f"加载config完成" + str(config))

78
extension/anlas.py Normal file
View File

@ -0,0 +1,78 @@
from pathlib import Path
import json
import aiofiles
from nonebot.adapters.onebot.v11 import Bot,GroupMessageEvent, Message, MessageSegment
from nonebot.permission import SUPERUSER
from nonebot.params import CommandArg
from nonebot import on_command, get_driver
jsonpath = Path("data/novelai/anlas.json").resolve()
setanlas = on_command(".anlas")
@setanlas.handle()
async def anlas_handle(bot:Bot,event: GroupMessageEvent, args: Message = CommandArg()):
atlist = []
user_id = str(event.user_id)
for seg in event.original_message["at"]:
atlist.append(seg.data["qq"])
messageraw = args.extract_plain_text().strip()
if not messageraw or messageraw == "help":
await setanlas.finish(f"点数计算方法(四舍五入):分辨率*数量*强度/45875\n.anlas+数字+@某人 将自己的点数分给对方\n.anlas check 查看自己的点数")
elif messageraw == "check":
if await SUPERUSER(bot,event):
await setanlas.finish(f"Master不需要点数哦")
else:
anlas = await anlas_check(user_id)
await setanlas.finish(f"你的剩余点数为{anlas}")
if atlist:
at = atlist[0]
if messageraw.isdigit():
anlas_change = int(messageraw)
if anlas_change > 1000:
await setanlas.finish(f"一次能给予的点数不超过1000")
if await SUPERUSER(bot,event):
_, result = await anlas_set(at, anlas_change)
message = f"分配完成:" + \
MessageSegment.at(at)+f"的剩余点数为{result}"
else:
result, user_anlas = await anlas_set(user_id, -anlas_change)
if result:
_, at_anlas = await anlas_set(at, anlas_change)
message = f"分配完成:\n"+MessageSegment.at(
user_id)+f"的剩余点数为{user_anlas}\n"+MessageSegment.at(at)+f"的剩余点数为{at_anlas}"
await setanlas.finish(message)
else:
await setanlas.finish(f"分配失败:点数不足,你的剩余点数为{user_anlas}")
await setanlas.finish(message)
else:
await setanlas.finish(f"请以正整数形式输入点数")
else:
await setanlas.finish(f"请@你希望给予点数的人")
async def anlas_check(user_id):
if not jsonpath.exists():
jsonpath.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(jsonpath, "w+")as f:
await f.write("{}")
async with aiofiles.open(jsonpath, "r") as f:
jsonraw = await f.read()
anlasdict: dict = json.loads(jsonraw)
anlas = anlasdict.get(user_id, 0)
return anlas
async def anlas_set(user_id, change):
oldanlas = await anlas_check(user_id)
newanlas = oldanlas+change
if newanlas < 0:
return False, oldanlas
anlasdict = {}
async with aiofiles.open(jsonpath, "r") as f:
jsonraw = await f.read()
anlasdict: dict = json.loads(jsonraw)
anlasdict[user_id] = newanlas
async with aiofiles.open(jsonpath, "w+") as f:
jsonnew = json.dumps(anlasdict)
await f.write(jsonnew)
return True, newanlas

20
extension/daylimit.py Normal file
View File

@ -0,0 +1,20 @@
import time
from ..config import config
class DayLimit():
day: int = time.localtime(time.time()).tm_yday
data: dict = {}
@classmethod
def count(cls, user: str, num):
day_ = time.localtime(time.time()).tm_yday
if day_ != cls.day:
cls.day = day_
cls.data = {}
count: int = cls.data.get(user, 0)+num
if count > config.novelai_daylimit:
return -1
else:
cls.data[user] = count
return config.novelai_daylimit-count

41
extension/deepdanbooru.py Normal file
View File

@ -0,0 +1,41 @@
import aiohttp
import base64
from nonebot import on_command
from nonebot.adapters.onebot.v11 import GroupMessageEvent, MessageSegment
from nonebot.log import logger
from .translation import translate
deepdanbooru = on_command(".gettag", aliases={"鉴赏", "查书"})
@deepdanbooru.handle()
async def deepdanbooru_handle(event: GroupMessageEvent):
url = ""
for seg in event.message['image']:
url = seg.data["url"]
if url:
async with aiohttp.ClientSession() as session:
logger.info(f"正在获取图片")
async with session.get(url) as resp:
bytes = await resp.read()
str_img = str(base64.b64encode(bytes), "utf-8")
message = MessageSegment.at(event.user_id)
start = "data:image/jpeg;base64,"
str0 = start+str_img
async with aiohttp.ClientSession() as session:
async with session.post('https://mayhug-rainchan-anime-image-label.hf.space/api/predict/', json={"data": [str0, 0.6,"ResNet101"]}) as resp:
if resp.status != 200:
await deepdanbooru.finish(f"识别失败,错误代码为{resp.status}")
jsonresult = await resp.json()
data = jsonresult['data'][0]
logger.info(f"TAG查询完毕")
tags = ""
for label in data['confidences']:
tags = tags+label["label"]+","
tags_ch = await translate(tags.replace("_", " "), "zh")
if tags_ch == tags.replace("_", " "):
message = message+tags
message = message+tags+f"\n机翻结果:"+tags_ch
await deepdanbooru.finish(message)
else:
await deepdanbooru.finish(f"未找到图片")

113
extension/translation.py Normal file
View File

@ -0,0 +1,113 @@
import aiohttp
from ..config import config
from nonebot.log import logger
async def translate(text: str, to: str):
# en,jp,zh
result = await translate_deepl(text, to) or await translate_bing(text, to) or await translate_google_proxy(text, to) or await translate_youdao(text, to)
if result:
return result
else:
logger.error(f"未找到可用的翻译引擎!")
return text
async def translate_bing(text: str, to: str):
"""
en,jp,zh_Hans
"""
if to == "zh":
to = "zh-Hans"
key = config.bing_key
if not key:
return None
header = {
"Ocp-Apim-Subscription-Key": key,
"Content-Type": "application/json",
}
async with aiohttp.ClientSession() as session:
body = [{'text': text}]
params = {
"api-version": "3.0",
"to": to,
"profanityAction": "Deleted",
}
async with session.post('https://api.cognitive.microsofttranslator.com/translate', json=body, params=params, headers=header) as resp:
if resp.status != 200:
logger.error(f"Bing翻译接口调用失败,错误代码{resp.status},{await resp.text()}")
return None
jsonresult = await resp.json()
result=jsonresult[0]["translations"][0]["text"]
logger.debug(f"Bing翻译启动获取到{text},翻译后{result}")
return result
async def translate_deepl(text: str, to: str):
"""
EN,JA,ZH
"""
to = to.upper()
key = config.deepl_key
if not key:
return None
async with aiohttp.ClientSession() as session:
params = {
"auth_key":key,
"text": text,
"target_lang": to,
}
async with session.get('https://api-free.deepl.com/v2/translate', params=params) as resp:
if resp.status != 200:
logger.error(f"DeepL翻译接口调用失败,错误代码{resp.status},{await resp.text()}")
return None
jsonresult = await resp.json()
result=jsonresult["translations"][0]["text"]
logger.debug(f"DeepL翻译启动获取到{text},翻译后{result}")
return result
async def translate_youdao(input: str, type: str):
"""
默认auto
ZH_CH2EN 中译英
EN2ZH_CN 英译汉
"""
if type == "zh":
type = "EN2ZH_CN"
elif type == "en":
type = "ZH_CH2EN"
async with aiohttp.ClientSession() as session:
data = {
'doctype': 'json',
'type': type,
'i': input
}
async with session.post("http://fanyi.youdao.com/translate", data=data) as resp:
if resp.status != 200:
logger.error(f"有道翻译接口调用失败,错误代码{resp.status},{await resp.text()}")
return None
result = await resp.json()
result=result["translateResult"][0][0]["tgt"]
logger.debug(f"有道翻译启动,获取到{input},翻译后{result}")
return result
async def translate_google_proxy(input: str, to: str):
"""
en,jp,zh 需要来源语言
"""
if to == "zh":
from_ = "en"
else:
from_="zh"
async with aiohttp.ClientSession()as session:
data = {"data": [input, from_, to]}
async with session.post("https://hf.space/embed/mikeee/gradio-gtr/+/api/predict", json=data)as resp:
if resp.status != 200:
logger.error(f"谷歌代理翻译接口调用失败,错误代码{resp.status},{await resp.text()}")
return None
result = await resp.json()
result=result["data"][0]
logger.debug(f"谷歌代理翻译启动,获取到{input},翻译后{result}")
return result

19
fifo.py Normal file
View File

@ -0,0 +1,19 @@
from collections import deque
class FIFO():
gennerating: dict={}
queue: deque = deque([])
@classmethod
def len(cls):
return len(cls.queue)+1 if cls.gennerating else len(cls.queue)
@classmethod
async def add(cls, aidraw):
cls.queue.append(aidraw)
await cls.gennerate()
@classmethod
async def gennerate(cls):
pass

0
locales/__init__.py Normal file
View File

0
locales/en.py Normal file
View File

0
locales/jp.py Normal file
View File

0
locales/moe_jp.py Normal file
View File

0
locales/moe_zh.py Normal file
View File

0
locales/zh.py Normal file
View File

42
manage.py Normal file
View File

@ -0,0 +1,42 @@
from nonebot.adapters.onebot.v11 import GROUP_ADMIN, GROUP_OWNER
from nonebot.adapters.onebot.v11 import GroupMessageEvent, Bot
from nonebot.permission import SUPERUSER
from nonebot.params import RegexGroup
from nonebot import on_regex
from nonebot.log import logger
from .config import config
on = on_regex(f"(?:^\.aidraw|^绘画|^aidraw)[ ]*(on$|off$|开启$|关闭$)",
priority=4, block=True)
set = on_regex(
"(?:^\.aidraw set|^绘画设置|^aidraw set)[ ]*([a-z]*)[ ]*(.*)", priority=4, block=True)
@set.handle()
async def set_(bot: Bot, event: GroupMessageEvent, args= RegexGroup()):
if await GROUP_ADMIN(bot, event) or await GROUP_OWNER(bot, event) or await SUPERUSER(bot, event):
if args[0] and args[1]:
key, value = args
await set.finish(f"设置群聊{key}{value}完成" if await config.set_value(event.group_id, key,
value) else f"不正确的赋值")
else:
group_config = await config.get_groupconfig(event.group_id)
message = "当前群的设置为\n"
for i, v in group_config.items():
message += f"{i}:{v}\n"
await set.finish(message)
else:
await set.send(f"权限不足!")
@on.handle()
async def on_(bot: Bot, event: GroupMessageEvent, args=RegexGroup()):
if await GROUP_ADMIN(bot, event) or await GROUP_OWNER(bot, event) or await SUPERUSER(bot, event):
if args[0] in ["on", "开启"]:
set = True
else:
set = False
result = await config.set_enable(event.group_id, set)
logger.info(result)
await on.finish(result)
else:
await on.send(f"权限不足!")

54
outofdate/explicit_api.py Normal file
View File

@ -0,0 +1,54 @@
from ..config import config, nickname
from ..utils.save import save_img
from io import BytesIO
import base64
import aiohttp
import asyncio
from nonebot.adapters.onebot.v11 import MessageSegment
from nonebot.log import logger
async def check_safe_method(fifo, img_bytes, message):
if config.novelai_h:
for i in img_bytes:
await save_img(fifo, i)
message += MessageSegment.image(i)
else:
nsfw_count = 0
for i in img_bytes:
try:
label = await check_safe(i)
except RuntimeError as e:
logger.error(f"NSFWAPI调用失败错误代码为{e.args}")
label = "unknown"
if label != "explicit":
message += MessageSegment.image(i)
else:
nsfw_count += 1
await save_img(fifo, i, label)
if nsfw_count > 0:
message += f"\n{nsfw_count}张图片太涩了,{nickname}已经帮你吃掉了哦"
return message
async def check_safe(img_bytes: BytesIO):
# 检查图片是否安全
start = "data:image/jpeg;base64,"
image = img_bytes.getvalue()
image = str(base64.b64encode(image), "utf-8")
str0 = start + image
# 重试三次
for i in range(3):
async with aiohttp.ClientSession() as session:
# 调用API
async with session.post('https://hf.space/embed/mayhug/rainchan-image-porn-detection/api/predict/',
json={"data": [str0]}) as resp:
if resp.status == 200:
jsonresult = await resp.json()
break
else:
await asyncio.sleep(2)
error = resp.status
else:
raise RuntimeError(error)
return jsonresult["data"][0]["label"]

48
utils/__init__.py Normal file
View File

@ -0,0 +1,48 @@
from io import BytesIO
from PIL import Image
import re
import aiohttp
import base64
async def check_last_version(package: str):
# 检查包的最新版本
async with aiohttp.ClientSession() as session:
async with session.get("https://pypi.org/simple/"+package) as resp:
text = await resp.text()
pattern = re.compile("-(\d.*?).tar.gz")
pypiversion = re.findall(pattern, text)[-1]
return pypiversion
async def compare_version(old: str, new: str):
# 比较两个版本哪个最新
oldlist = old.split(".")
newlist = new.split(".")
for i in range(len(oldlist)):
if int(newlist[i]) > int(oldlist[i]):
return True
return False
async def sendtosuperuser(message):
# 将消息发送给superuser
from nonebot import get_bot, get_driver
import asyncio
superusers = get_driver().config.superusers
bot = get_bot()
for superuser in superusers:
await bot.call_api('send_msg', **{
'message': message,
'user_id': superuser,
})
await asyncio.sleep(5)
async def png2jpg(raw: bytes):
raw:BytesIO = BytesIO(base64.b64decode(raw))
img_PIL = Image.open(raw).convert("RGB")
image_new = BytesIO()
img_PIL.save(image_new, format="JPEG", quality=95)
image_new=image_new.getvalue()
return image_new

20
utils/data.py Normal file
View File

@ -0,0 +1,20 @@
# 基础优化tag
basetag = "masterpiece, best quality,"
# 基础排除tag
lowQuality = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, pubic hair,long neck,blurry"
# 屏蔽词
htags = "nsfw|nude|naked|nipple|blood|censored|vagina|gag|gokkun|hairjob|tentacle|oral|fellatio|areolae|lactation|paizuri|piercing|sex|footjob|masturbation|hips|penis|testicles|ejaculation|cum|tamakeri|pussy|pubic|clitoris|mons|cameltoe|grinding|crotch|cervix|cunnilingus|insertion|penetration|fisting|fingering|peeing|ass|buttjob|spanked|anus|anal|anilingus|enema|x-ray|wakamezake|humiliation|tally|futa|incest|twincest|pegging|femdom|ganguro|bestiality|gangbang|3P|tribadism|molestation|voyeurism|exhibitionism|rape|spitroast|cock|69|doggystyle|missionary|virgin|shibari|bondage|bdsm|rope|pillory|stocks|bound|hogtie|frogtie|suspension|anal|dildo|vibrator|hitachi|nyotaimori|vore|amputee|transformation|bloody"
shapemap = {
"square": [640, 640],
"s": [640, 640],
"": [640, 640],
"portrait": [512, 768],
"p": [512, 768],
"": [512, 768],
"landscape": [768, 512],
"l": [768, 512],
"": [768, 512]
}

38
utils/prepocess.py Normal file
View File

@ -0,0 +1,38 @@
import re
from ..extension.translation import translate
escape_table = {
'&#91;': '[',
'&#93;': ']'
}
async def prepocess_tags(tags: list[str]):
tags: str = "".join([i+" " for i in tags if isinstance(i, str)])
# 去除CQ码
tags = re.sub("\[CQ[^\s]*?]", "", tags)
# 检测中文
taglist = tags.split(",")
tagzh = ""
tags_ = ""
for i in taglist:
if re.search('[\u4e00-\u9fa5]', tags):
tagzh += f"{i},"
else:
tags_ += f"{i},"
if tagzh:
tags_en = await translate(tagzh, "en")
if tags_en == tagzh:
return ""
else:
tags_ += tags_en
return await fix_char_escape(tags_)
async def combine_multi_args(args: list[str]):
return ' '.join(args)
async def fix_char_escape(tags: str):
for escape, raw in escape_table.items():
tags = tags.replace(escape, raw)
return tags

17
utils/save.py Normal file
View File

@ -0,0 +1,17 @@
from ..config import config
from pathlib import Path
import hashlib
import aiofiles
path = Path("data/novelai/output").resolve()
async def save_img(fifo, img_bytes: bytes, extra: str = "unknown"):
# 存储图片
if config.novelai_save:
path_ = path / extra
path_.mkdir(parents=True, exist_ok=True)
hash = hashlib.md5(img_bytes).hexdigest()
file = (path_ / hash).resolve()
async with aiofiles.open(str(file) + ".jpg", "wb") as f:
await f.write(img_bytes)
if config.novelai_save==2:
async with aiofiles.open(str(file) + ".txt", "w") as f:
await f.write(repr(fifo))

50
version.py Normal file
View File

@ -0,0 +1,50 @@
import time
from importlib.metadata import version
from nonebot.log import logger
from .utils import check_last_version, sendtosuperuser, compare_version
class Version():
version: str # 当前版本
lastcheck: float = 0 # 上次检查时间
ispushed: bool = True # 是否已经推送
latest: str = "0.0.0" # 最新版本
package = "nonebot-plugin-novelai"
url = "https://sena-nana.github.io/MutsukiDocs/update/novelai/"
def __init__(self):
# 初始化当前版本
try:
self.version = version(self.package)
except:
self.version = "0.5.7"
async def check_update(self):
"""检查更新,并推送"""
# 每日检查
if time.time() - self.lastcheck > 80000:
update = await check_last_version(self.package)
# 判断是否重复检查
if await compare_version(self.latest, update):
self.latest = update
# 判断是否是新版本
if await compare_version(self.version, self.latest):
logger.info(self.push_txt())
self.ispushed = False
else:
logger.info(f"novelai插件检查版本完成当前版本{self.version},最新版本{self.latest}")
else:
logger.info(f"novelai插件检查版本完成当前版本{self.version},最新版本{self.latest}")
self.lastcheck = time.time()
# 如果没有推送,则启动推送流程
if not self.ispushed:
await sendtosuperuser(self.push_txt())
self.ispushed = True
def push_txt(self):
# 获取推送文本
logger.debug(self.__dict__)
return f"novelai插件检测到新版本{self.latest},当前版本{self.version},请使用pip install --upgrade {self.package}命令升级,更新日志:{self.url}"
version = Version()