新增指定sampler, hires的功能
修复tags中方括号编码不正常的问题
This commit is contained in:
		
							
								
								
									
										48
									
								
								utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,48 @@
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
from PIL import Image
 | 
			
		||||
import re
 | 
			
		||||
import aiohttp
 | 
			
		||||
import base64
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def check_last_version(package: str):
 | 
			
		||||
    # 检查包的最新版本
 | 
			
		||||
    async with aiohttp.ClientSession() as session:
 | 
			
		||||
        async with session.get("https://pypi.org/simple/"+package) as resp:
 | 
			
		||||
            text = await resp.text()
 | 
			
		||||
            pattern = re.compile("-(\d.*?).tar.gz")
 | 
			
		||||
            pypiversion = re.findall(pattern, text)[-1]
 | 
			
		||||
    return pypiversion
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def compare_version(old: str, new: str):
 | 
			
		||||
    # 比较两个版本哪个最新
 | 
			
		||||
    oldlist = old.split(".")
 | 
			
		||||
    newlist = new.split(".")
 | 
			
		||||
    for i in range(len(oldlist)):
 | 
			
		||||
        if int(newlist[i]) > int(oldlist[i]):
 | 
			
		||||
            return True
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def sendtosuperuser(message):
 | 
			
		||||
    # 将消息发送给superuser
 | 
			
		||||
    from nonebot import get_bot, get_driver
 | 
			
		||||
    import asyncio
 | 
			
		||||
    superusers = get_driver().config.superusers
 | 
			
		||||
    bot = get_bot()
 | 
			
		||||
    for superuser in superusers:
 | 
			
		||||
        await bot.call_api('send_msg', **{
 | 
			
		||||
            'message': message,
 | 
			
		||||
            'user_id': superuser,
 | 
			
		||||
        })
 | 
			
		||||
        await asyncio.sleep(5)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def png2jpg(raw: bytes):
 | 
			
		||||
    raw:BytesIO = BytesIO(base64.b64decode(raw))
 | 
			
		||||
    img_PIL = Image.open(raw).convert("RGB")
 | 
			
		||||
    image_new = BytesIO()
 | 
			
		||||
    img_PIL.save(image_new, format="JPEG", quality=95)
 | 
			
		||||
    image_new=image_new.getvalue()
 | 
			
		||||
    return image_new
 | 
			
		||||
							
								
								
									
										20
									
								
								utils/data.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								utils/data.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
			
		||||
# 基础优化tag
 | 
			
		||||
basetag = "masterpiece, best quality,"
 | 
			
		||||
 | 
			
		||||
# 基础排除tag
 | 
			
		||||
lowQuality = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, pubic hair,long neck,blurry"
 | 
			
		||||
 | 
			
		||||
# 屏蔽词
 | 
			
		||||
htags = "nsfw|nude|naked|nipple|blood|censored|vagina|gag|gokkun|hairjob|tentacle|oral|fellatio|areolae|lactation|paizuri|piercing|sex|footjob|masturbation|hips|penis|testicles|ejaculation|cum|tamakeri|pussy|pubic|clitoris|mons|cameltoe|grinding|crotch|cervix|cunnilingus|insertion|penetration|fisting|fingering|peeing|ass|buttjob|spanked|anus|anal|anilingus|enema|x-ray|wakamezake|humiliation|tally|futa|incest|twincest|pegging|femdom|ganguro|bestiality|gangbang|3P|tribadism|molestation|voyeurism|exhibitionism|rape|spitroast|cock|69|doggystyle|missionary|virgin|shibari|bondage|bdsm|rope|pillory|stocks|bound|hogtie|frogtie|suspension|anal|dildo|vibrator|hitachi|nyotaimori|vore|amputee|transformation|bloody"
 | 
			
		||||
 | 
			
		||||
shapemap = {
 | 
			
		||||
    "square": [640, 640],
 | 
			
		||||
    "s": [640, 640],
 | 
			
		||||
    "方": [640, 640],
 | 
			
		||||
    "portrait": [512, 768],
 | 
			
		||||
    "p": [512, 768],
 | 
			
		||||
    "高": [512, 768],
 | 
			
		||||
    "landscape": [768, 512],
 | 
			
		||||
    "l": [768, 512],
 | 
			
		||||
    "宽": [768, 512]
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										38
									
								
								utils/prepocess.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								utils/prepocess.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,38 @@
 | 
			
		||||
import re
 | 
			
		||||
from ..extension.translation import translate
 | 
			
		||||
 | 
			
		||||
escape_table = {
 | 
			
		||||
    '[': '[',
 | 
			
		||||
    ']': ']'
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async def prepocess_tags(tags: list[str]):
 | 
			
		||||
    tags: str = "".join([i+" " for i in tags if isinstance(i, str)])
 | 
			
		||||
    # 去除CQ码
 | 
			
		||||
    tags = re.sub("\[CQ[^\s]*?]", "", tags)
 | 
			
		||||
    # 检测中文
 | 
			
		||||
    taglist = tags.split(",")
 | 
			
		||||
    tagzh = ""
 | 
			
		||||
    tags_ = ""
 | 
			
		||||
    for i in taglist:
 | 
			
		||||
        if re.search('[\u4e00-\u9fa5]', tags):
 | 
			
		||||
            tagzh += f"{i},"
 | 
			
		||||
        else:
 | 
			
		||||
            tags_ += f"{i},"
 | 
			
		||||
    if tagzh:
 | 
			
		||||
        tags_en = await translate(tagzh, "en")
 | 
			
		||||
        if tags_en == tagzh:
 | 
			
		||||
            return ""
 | 
			
		||||
        else:
 | 
			
		||||
            tags_ += tags_en
 | 
			
		||||
    return await fix_char_escape(tags_)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def combine_multi_args(args: list[str]):
 | 
			
		||||
    return ' '.join(args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def fix_char_escape(tags: str):
 | 
			
		||||
    for escape, raw in escape_table.items():
 | 
			
		||||
        tags = tags.replace(escape, raw)
 | 
			
		||||
    return tags
 | 
			
		||||
							
								
								
									
										17
									
								
								utils/save.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								utils/save.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
			
		||||
from ..config import config
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
import hashlib
 | 
			
		||||
import aiofiles
 | 
			
		||||
path = Path("data/novelai/output").resolve()
 | 
			
		||||
async def save_img(fifo, img_bytes: bytes, extra: str = "unknown"):
 | 
			
		||||
    # 存储图片
 | 
			
		||||
    if config.novelai_save:
 | 
			
		||||
        path_ = path / extra
 | 
			
		||||
        path_.mkdir(parents=True, exist_ok=True)
 | 
			
		||||
        hash = hashlib.md5(img_bytes).hexdigest()
 | 
			
		||||
        file = (path_ / hash).resolve()
 | 
			
		||||
        async with aiofiles.open(str(file) + ".jpg", "wb") as f:
 | 
			
		||||
            await f.write(img_bytes)
 | 
			
		||||
        if config.novelai_save==2:
 | 
			
		||||
            async with aiofiles.open(str(file) + ".txt", "w") as f:
 | 
			
		||||
                await f.write(repr(fifo))
 | 
			
		||||
		Reference in New Issue
	
	Block a user