forked from FlintyLemming/ComfyUI-AnyText
214 lines
9.6 KiB
Python
214 lines
9.6 KiB
Python
|
import os
|
|||
|
import folder_paths
|
|||
|
import torch
|
|||
|
import numpy as np
|
|||
|
import time
|
|||
|
from PIL import Image
|
|||
|
|
|||
|
# Folder containing this module; bundled files such as models_yaml/ are resolved against it.
current_directory = os.path.dirname(os.path.abspath(__file__))

# ComfyUI's models folder, as reported by folder_paths.
comfyui_models_dir = folder_paths.models_dir

# ComfyUI's temp folder, as reported by folder_paths.
comfy_temp_dir = folder_paths.get_temp_directory()

# Text file inside the temp folder that receives the translator node's latest output.
temp_txt_path = os.path.join(comfy_temp_dir, "AnyText_temp.txt")
|
|||
|
|
|||
|
class AnyText_loader:
    """Node that gathers the AnyText model paths and packs them into one string.

    The node outputs a single ``"|"``-separated string holding, in order, the
    font path, checkpoint path, clip path, translator path and yaml config path.
    """

    @staticmethod
    def _list_model_subdir(subfolder):
        """Return the entries of ``<models_dir>/<subfolder>``.

        A missing folder yields an empty list instead of raising, so the node
        can still be listed with only the "Auto_DownLoad" choice.
        """
        try:
            return os.listdir(os.path.join(comfyui_models_dir, subfolder))
        except (FileNotFoundError, NotADirectoryError):
            return []

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs; every file choice starts with "Auto_DownLoad"."""
        # Fresh lists are built so the list returned by folder_paths is never mutated.
        font_list = ["Auto_DownLoad", *cls._list_model_subdir("fonts")]
        checkpoints_list = ["Auto_DownLoad", *folder_paths.get_filename_list("checkpoints")]
        clip_list = ["Auto_DownLoad", *cls._list_model_subdir("clip")]

        return {
            "required": {
                "font": (font_list, ),
                "ckpt_name": (checkpoints_list, ),
                "clip": (clip_list, ),
                "translator": (["utrobinmv/t5_translate_en_ru_zh_small_1024", "damo/nlp_csanmt_translation_zh2en"],{"default": "utrobinmv/t5_translate_en_ru_zh_small_1024"}),
            }
        }

    RETURN_TYPES = ("AnyText_Loader", )
    RETURN_NAMES = ("AnyText_Loader", )
    FUNCTION = "AnyText_loader_fn"
    CATEGORY = "ExtraModels/AnyText"
    TITLE = "AnyText Loader"

    def AnyText_loader_fn(self,
                          font,
                          ckpt_name,
                          clip,
                          translator,
                          ):
        """Resolve the selected names to paths and join them with ``"|"``.

        Args:
            font: Entry chosen from ``models/fonts`` (or "Auto_DownLoad").
            ckpt_name: Checkpoint name resolved through ``folder_paths``.
            clip: Entry chosen from ``models/clip``; "Auto_DownLoad" is passed
                through unchanged.
            translator: Translator model name; "Auto_DownLoad" is passed
                through unchanged, anything else is placed under
                ``models/prompt_generator``.

        Returns:
            A one-element tuple holding
            ``font_path|ckpt_path|clip_path|translator_path|cfg_path``.
        """
        font_path = os.path.join(comfyui_models_dir, "fonts", font)
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        cfg_path = os.path.join(current_directory, 'models_yaml', 'anytext_sd15.yaml')

        if clip != 'Auto_DownLoad':
            clip_path = os.path.join(comfyui_models_dir, "clip", clip)
        else:
            clip_path = clip

        if translator != 'Auto_DownLoad':
            translator_path = os.path.join(comfyui_models_dir, "prompt_generator", translator)
        else:
            translator_path = translator

        # Merge all inputs into a single value that is handed on to .nodes.
        # ckpt_path goes through str() because it is not guaranteed to be a string.
        loader = "|".join([font_path, str(ckpt_path), clip_path, translator_path, cfg_path])
        return (loader, )
|
|||
|
|
|||
|
class AnyText_translator:
    """Node that translates a prompt with either the CSANMT zh2en or the T5 en/ru/zh model.

    The translated text is returned and also written to ``temp_txt_path``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (model choice, texts, batch switches, device)."""
        return {
            "required": {
                "model": (["utrobinmv/t5_translate_en_ru_zh_small_1024", "damo/nlp_csanmt_translation_zh2en"],{"default": "utrobinmv/t5_translate_en_ru_zh_small_1024"}),
                "prompt": ("STRING", {"default": "这里是单批次翻译文本输入。\n声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。\n尽量输入单句文本,如果是多句长文本建议人工分句,否则可能出现漏译或未译等情况!!!\n使用换行,效果可能更佳。", "multiline": True}),
                "Batch_prompt": ("STRING", {"default": "这里是多批次翻译文本输入,使用换行进行分割。\n天上掉馅饼啦,快去看超人!!!\n飞流直下三千尺,疑似银河落九天。\n启用Batch_Newline表示输出的翻译会按换行输入进行二次换行,否则是用空格合并起来的整篇文本。", "multiline": True}),
                "t5_Target_Language": (["en", "zh", "ru", ],{"default": "en"}),
                "if_Batch": ("BOOLEAN", {"default": False}),
                "Batch_Newline" :("BOOLEAN", {"default": True}),
                "device": (["auto", "cuda", "cpu", "mps", "xpu"],{"default": "auto"}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("text",)
    CATEGORY = "ExtraModels/AnyText"
    FUNCTION = "AnyText_translator"
    TITLE = "AnyText Translator"

    def AnyText_translator(self, prompt, model, Batch_prompt, if_Batch, device, Batch_Newline, t5_Target_Language):
        """Translate the input text and return it as a one-element tuple.

        Args:
            prompt: Text translated when ``if_Batch`` is false.
            model: Name of the translation model to use.
            Batch_prompt: Newline-separated sentences translated when
                ``if_Batch`` is true.
            if_Batch: Select ``Batch_prompt`` instead of ``prompt``.
            device: Device name; "auto" is resolved by ``get_device_by_name``.
            Batch_Newline: In batch mode, join the translated sentences with a
                blank line instead of a space.
            t5_Target_Language: Target language for the T5 model.

        Returns:
            ``(translated_text, )``.
        """
        device = get_device_by_name(device)
        # Batch input holds one sentence per line.
        batch_lines = Batch_prompt.split("\n")

        if model == 'damo/nlp_csanmt_translation_zh2en':
            results = self._translate_with_csanmt(prompt, batch_lines, if_Batch, Batch_Newline, device)
        else:
            results = self._translate_with_t5(prompt, batch_lines, if_Batch, Batch_Newline, device, t5_Target_Language)

        # The temp folder may not exist yet, so create it before writing.
        os.makedirs(os.path.dirname(temp_txt_path), exist_ok=True)
        with open(temp_txt_path, "w", encoding="UTF-8") as text_file:
            text_file.write(results)
        return (results, )

    @staticmethod
    def _join_batch(parts, Batch_Newline):
        """Join translated sentences with a blank line or with a single space."""
        separator = '\n\n' if Batch_Newline else ' '
        return separator.join(parts)

    def _translate_with_csanmt(self, prompt, batch_lines, if_Batch, Batch_Newline, device):
        """Run the ModelScope CSANMT zh2en pipeline and return the translated text."""
        sttime = time.time()
        # Imported lazily so modelscope is only required when this model is selected.
        from modelscope.pipelines import pipeline
        from modelscope.utils.constant import Tasks

        # Sentences of a batch are chained with the <SENT_SPLIT> marker.
        input_sequence = '<SENT_SPLIT>'.join(batch_lines) if if_Batch else prompt

        # Prefer a local copy of the model when its checkpoint file is present.
        local_model_dir = os.path.join(comfyui_models_dir, 'prompt_generator', 'nlp_csanmt_translation_zh2en')
        checkpoint_file = os.path.join(local_model_dir, "tf_ckpts", "ckpt-0.data-00000-of-00001")
        if os.access(checkpoint_file, os.F_OK):
            zh2en_path = local_model_dir
        else:
            zh2en_path = "damo/nlp_csanmt_translation_zh2en"

        # The pipeline is created for every device; only CUDA selects 'gpu',
        # every other device name runs on the CPU.
        pipeline_device = 'gpu' if device == 'cuda' else 'cpu'
        pipeline_ins = pipeline(task=Tasks.translation, model=zh2en_path, device=pipeline_device)
        outputs = pipeline_ins(input=input_sequence)

        if if_Batch:
            results = self._join_batch(outputs['translation'].split('<SENT_SPLIT>'), Batch_Newline)
        else:
            results = outputs['translation']

        endtime = time.time()
        print("\033[93mTime for translating(翻译耗时): ", endtime - sttime, "\033[0m")

        # Drop the pipeline and release cached CUDA memory.
        del pipeline_ins
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return results

    def _translate_with_t5(self, prompt, batch_lines, if_Batch, Batch_Newline, device, t5_Target_Language):
        """Run the T5 en/ru/zh model and return the translated text."""
        # Sentences of a batch are chained with "|".
        input_sequence = '|'.join(batch_lines) if if_Batch else prompt

        # Prefer a local copy of the model when its weights file is present.
        self.zh2en_path = os.path.join(folder_paths.models_dir, "prompt_generator", "models--utrobinmv--t5_translate_en_ru_zh_small_1024")
        if not os.access(os.path.join(self.zh2en_path, "model.safetensors"), os.F_OK):
            self.zh2en_path = "utrobinmv/t5_translate_en_ru_zh_small_1024"

        outputs = t5_translate_en_ru_zh(t5_Target_Language, input_sequence, self.zh2en_path, device)[0]

        if if_Batch:
            # The translated output is split on "| " (pipe followed by a space).
            return self._join_batch(outputs.split('| '), Batch_Newline)
        return outputs
|
|||
|
|
|||
|
def is_module_imported(module_name):
    """Report whether ``module_name`` can be imported.

    Despite its name, this attempts the import itself rather than checking
    whether the module was imported earlier.

    Returns:
        True when the import succeeds, False when it raises ImportError.
    """
    try:
        __import__(module_name)
    except ImportError:
        return False
    return True
|
|||
|
|
|||
|
def pil2tensor(image):
    """Convert an image to a float32 tensor scaled by 1/255 with a leading batch axis."""
    pixels = np.array(image).astype(np.float32)
    scaled = pixels / 255.0
    tensor = torch.from_numpy(scaled)
    return tensor.unsqueeze(0)
|
|||
|
|
|||
|
def is_folder_exist(folder_path):
    """Return True when ``folder_path`` exists on disk (file or directory)."""
    return os.path.exists(folder_path)
|
|||
|
|
|||
|
def _backend_is_available(backend):
    """Return True when ``backend`` exposes ``is_available()`` and it reports True.

    A missing backend module (``None``) or a probe that raises counts as
    unavailable, so torch builds without that backend are handled gracefully.
    """
    probe = getattr(backend, "is_available", None)
    if probe is None:
        return False
    try:
        return bool(probe())
    except (AttributeError, RuntimeError):
        return False


def get_device_by_name(device):
    """Resolve a device name, picking the best available one for "auto".

    Args:
        device: "auto", or an explicit device name that is returned unchanged.

    Returns:
        For "auto": "cuda", "mps" or "xpu" (checked in that order) when the
        backend reports itself available, otherwise "cpu". Any other value is
        returned as given.
    """
    if device == 'auto':
        # CPU is the fallback; a backend missing from this torch build must
        # not prevent that fallback from being returned.
        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda"
        elif _backend_is_available(getattr(torch.backends, "mps", None)):
            device = "mps"
        elif _backend_is_available(getattr(torch, "xpu", None)):
            device = "xpu"
    print("\033[93mUse Device(使用设备):", device, "\033[0m")
    return device
|
|||
|
|
|||
|
# Node class mappings: each node identifier string points at the class that
# implements it (presumably read by ComfyUI when registering custom nodes).
NODE_CLASS_MAPPINGS = dict([
    ("AnyText_loader", AnyText_loader),
    ("AnyText_translator", AnyText_translator),
])
|
|||
|
|
|||
|
def t5_translate_en_ru_zh(Target_Language, prompt, model_path, device):
    """Translate ``prompt`` with a T5 en/ru/zh translation checkpoint.

    Args:
        Target_Language: 'zh' or 'en'; any other value selects Russian.
        prompt: Source text to translate.
        model_path: Local directory or hub id handed to ``from_pretrained``.
        device: Torch device name on which generation is run.

    Returns:
        The list of decoded strings produced by ``tokenizer.batch_decode``.
    """
    sttime = time.time()
    # Imported lazily so transformers is only required when this model is used.
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    model = T5ForConditionalGeneration.from_pretrained(model_path)
    tokenizer = T5Tokenizer.from_pretrained(model_path)

    # The model selects the target language from this text prefix.
    prefixes = {'zh': 'translate to zh: ', 'en': 'translate to en: '}
    prefix = prefixes.get(Target_Language, 'translate to ru: ')
    src_text = prefix + prompt

    # Model and inputs must sit on the same device for generation to use it.
    model.to(device)
    inputs = tokenizer(src_text, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_tokens = model.generate(**inputs)
    # Token ids are decoded as integers; no float cast is applied.
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    # Move the weights back to the CPU once generation is done.
    model.to('cpu')
    endtime = time.time()
    print("\033[93mTime for translating(翻译耗时): ", endtime - sttime, "\033[0m")
    return result
|
|||
|
|
|||
|
def comfy_tensor_Image2np_Image(self, comfy_tensor_image):
    """Convert the first image of a tensor batch into a PIL image.

    Args:
        self: Unused; kept so existing callers that pass it keep working.
        comfy_tensor_image: Tensor whose first element along axis 0 is the
            image to convert (assumes values in [0, 1] — TODO confirm).

    Returns:
        A ``PIL.Image.Image`` built from the uint8 pixel array.
    """
    # detach()/cpu() make the conversion work for grad-tracking or non-CPU tensors.
    first_image = comfy_tensor_image.detach().cpu().numpy()[0]
    # Clip before the cast so out-of-range values saturate instead of wrapping.
    image_np = np.clip(first_image * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(image_np)
|