diff --git a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py
index bfa7701b1..0241d7586 100644
--- a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py
+++ b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py
@@ -28,6 +28,8 @@
     EulerAncestralDiscreteScheduler,
     DPMSolverMultistepScheduler,
     ControlNetModel,
+    DDIMScheduler,
+    DDPMScheduler,
     StableDiffusionControlNetPAGPipeline,
     StableDiffusionXLControlNetPAGPipeline,
     FluxControlNetPipeline,
@@ -39,6 +41,21 @@
     StableDiffusionXLControlNetImg2ImgPipeline,
     StableDiffusionXLControlNetUnionImg2ImgPipeline,
     StableDiffusionXLControlNetPAGImg2ImgPipeline,
+    StableDiffusionPipeline,
+    StableDiffusionImg2ImgPipeline,
+    StableDiffusionInpaintPipeline,
+    StableDiffusionXLImg2ImgPipeline,
+    StableDiffusionXLInpaintPipeline,
+    StableDiffusionXLInstructPix2PixPipeline,
+    StableDiffusionXLKDiffusionPipeline,
+    StableDiffusion3Pipeline,
+    LatentConsistencyModelPipeline,
+    LatentConsistencyModelImg2ImgPipeline,
+    StableDiffusionControlNetXSPipeline,
+    StableDiffusionXLControlNetXSPipeline,
+    LEditsPPPipelineStableDiffusion,
+    LEditsPPPipelineStableDiffusionXL,
+    PIAPipeline,
     FluxControlNetImg2ImgPipeline,
     StableDiffusionControlNetInpaintPipeline,
     StableDiffusionXLControlNetInpaintPipeline,
@@ -68,6 +85,26 @@
     "LMSDiscreteScheduler": LMSDiscreteScheduler,
     "EulerAncestralDiscreteScheduler": EulerAncestralDiscreteScheduler,
     "DPMSolverMultistepScheduler": DPMSolverMultistepScheduler,
+    "DDIMScheduler": DDIMScheduler,
+    "DDPMScheduler": DDPMScheduler,
+}
+
+SINGLE_FILE_MAP = {
+    "StableDiffusionPipeline": StableDiffusionPipeline,
+    "StableDiffusionImg2ImgPipeline": StableDiffusionImg2ImgPipeline,
+    "StableDiffusionInpaintPipeline": StableDiffusionInpaintPipeline,
+    "StableDiffusionXLImg2ImgPipeline": StableDiffusionXLImg2ImgPipeline,
+    "StableDiffusionXLInpaintPipeline": StableDiffusionXLInpaintPipeline,
+    "StableDiffusionXLInstructPix2PixPipeline": StableDiffusionXLInstructPix2PixPipeline,
+    "StableDiffusionXLKDiffusionPipeline": StableDiffusionXLKDiffusionPipeline,
+    "StableDiffusion3Pipeline": StableDiffusion3Pipeline,
+    "LatentConsistencyModelPipeline": LatentConsistencyModelPipeline,
+    "LatentConsistencyModelImg2ImgPipeline": LatentConsistencyModelImg2ImgPipeline,
+    "StableDiffusionControlNetXSPipeline": StableDiffusionControlNetXSPipeline,
+    "StableDiffusionXLControlNetXSPipeline": StableDiffusionXLControlNetXSPipeline,
+    "LEditsPPPipelineStableDiffusion": LEditsPPPipelineStableDiffusion,
+    "LEditsPPPipelineStableDiffusionXL": LEditsPPPipelineStableDiffusionXL,
+    "PIAPipeline": PIAPipeline,
 }
 
 
@@ -96,6 +133,13 @@ def latents_to_rgb(latents):
     return Image.fromarray(image_array)
 
 
+def is_single_file_model(model_path):
+    """Check if the model is a single-file checkpoint (.safetensors, .ckpt, or .pt)"""
+    if isinstance(model_path, str):
+        return model_path.endswith((".safetensors", ".ckpt", ".pt"))
+    return False
+
+
 def create_decode_callback(output_dir):
     """Create a callback function to decode and save latents at each step"""
 
@@ -374,7 +418,6 @@ def load_pipeline_with_sharding(
 
     # Handle 'auto' device safely by falling back to cuda:0 for controlnet
     safe_device = "cuda:0" if device == "auto" else device
-
     controlnet = controlnet_class.from_pretrained(
         controlnet_id,
         torch_dtype=torch.float16 if safe_device != "cpu" else torch.float32,
@@ -598,6 +641,25 @@ def load_pipeline_with_device_map(
         torch.cuda.empty_cache()
 
     # Load appropriate pipeline
+    is_single_file = is_single_file_model(model_path)
+    if is_single_file:
+        if is_inpainting:
+            architecture = "StableDiffusionInpaintPipeline"
+        elif is_img2img:
+            architecture = "StableDiffusionImg2ImgPipeline"
+        else:
+            architecture = "StableDiffusionPipeline"
+        single_pipeline_class = SINGLE_FILE_MAP.get(architecture)
+        if not single_pipeline_class:
+            raise ValueError(f"Model architecture '{architecture}' not supported for single-file models")
+
+        pipe = single_pipeline_class.from_single_file(
+            model_path,
+            torch_dtype=torch.float16 if device != "cpu" else torch.float32,
+        )
+        if device != "auto":
+            pipe = pipe.to(device)
+        return pipe
     if is_controlnet:
         CONTROLNET_PIPELINE_MAP = {
             "StableDiffusionPipeline": StableDiffusionControlNetPipeline,
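
A minimal sketch of how the new single-file branch is meant to be exercised, assuming a hypothetical local checkpoint at ./model.safetensors and a CUDA device; the extension check mirrors is_single_file_model() above, and the loaders are diffusers' from_single_file / from_pretrained.

import torch
from diffusers import StableDiffusionPipeline

model_path = "./model.safetensors"  # hypothetical single-file SD 1.x checkpoint

# Same extension test as is_single_file_model(): single-file checkpoints go through
# from_single_file, while repo-style model directories/IDs use from_pretrained.
if model_path.endswith((".safetensors", ".ckpt", ".pt")):
    pipe = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16)
else:
    pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)

pipe = pipe.to("cuda")
image = pipe("a watercolor fox in a pine forest", num_inference_steps=25).images[0]
image.save("preview.png")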