Skip to content
67 changes: 66 additions & 1 deletion api/transformerlab/plugins/image_diffusion/diffusion_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
ControlNetModel,
DDIMScheduler,
DDPMScheduler,
StableDiffusionControlNetPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
FluxControlNetPipeline,
Expand All @@ -39,6 +41,21 @@
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetUnionImg2ImgPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLKDiffusionPipeline,
StableDiffusion3Pipeline,
LatentConsistencyModelPipeline,
LatentConsistencyModelImg2ImgPipeline,
StableDiffusionControlNetXSPipeline,
StableDiffusionXLControlNetXSPipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
PIAPipeline,
FluxControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionXLControlNetInpaintPipeline,
Expand Down Expand Up @@ -68,6 +85,26 @@
"LMSDiscreteScheduler": LMSDiscreteScheduler,
"EulerAncestralDiscreteScheduler": EulerAncestralDiscreteScheduler,
"DPMSolverMultistepScheduler": DPMSolverMultistepScheduler,
"DDIMScheduler": DDIMScheduler,
"DDPMScheduler": DDPMScheduler,
}

SINGLE_FILE_MAP = {
"StableDiffusionPipeline": StableDiffusionPipeline,
"StableDiffusionImg2ImgPipeline": StableDiffusionImg2ImgPipeline,
"StableDiffusionInpaintPipeline": StableDiffusionInpaintPipeline,
"StableDiffusionXLImg2ImgPipeline": StableDiffusionXLImg2ImgPipeline,
"StableDiffusionXLInpaintPipeline": StableDiffusionXLInpaintPipeline,
"StableDiffusionXLInstructPix2PixPipeline": StableDiffusionXLInstructPix2PixPipeline,
"StableDiffusionXLKDiffusionPipeline": StableDiffusionXLKDiffusionPipeline,
"StableDiffusion3Pipeline": StableDiffusion3Pipeline,
"LatentConsistencyModelPipeline": LatentConsistencyModelPipeline,
"LatentConsistencyModelImg2ImgPipeline": LatentConsistencyModelImg2ImgPipeline,
"StableDiffusionControlNetXSPipeline": StableDiffusionControlNetXSPipeline,
"StableDiffusionXLControlNetXSPipeline": StableDiffusionXLControlNetXSPipeline,
"LEditsPPPipelineStableDiffusion": LEditsPPPipelineStableDiffusion,
"LEditsPPPipelineStableDiffusionXL": LEditsPPPipelineStableDiffusionXL,
"PIAPipeline": PIAPipeline,
}


Expand Down Expand Up @@ -96,6 +133,13 @@ def latents_to_rgb(latents):
return Image.fromarray(image_array)


def is_single_file_model(model_path):
"""Check if the model is a single-file format (.safetensors or .ckpt)"""
if isinstance(model_path, str):
return model_path.endswith((".safetensors", ".ckpt", ".pt"))
return False


def create_decode_callback(output_dir):
"""Create a callback function to decode and save latents at each step"""

Expand Down Expand Up @@ -374,7 +418,6 @@ def load_pipeline_with_sharding(

# Handle 'auto' device safely by falling back to cuda:0 for controlnet
safe_device = "cuda:0" if device == "auto" else device

controlnet = controlnet_class.from_pretrained(
controlnet_id,
torch_dtype=torch.float16 if safe_device != "cpu" else torch.float32,
Expand Down Expand Up @@ -598,6 +641,28 @@ def load_pipeline_with_device_map(
torch.cuda.empty_cache()

# Load appropriate pipeline
is_single_file = is_single_file_model(model_path)
if is_single_file:
if is_inpainting:
architecture = "StableDiffusionInpaintPipeline"
elif is_img2img:
architecture = "StableDiffusionImg2ImgPipeline"
else:
architecture = "StableDiffusionPipeline"
single_pipeline_class = SINGLE_FILE_MAP.get(architecture)
if not single_pipeline_class:
raise ValueError(f"Model architecture '{architecture}' not supported for single-file models")

if single_pipeline_class:
pipe = single_pipeline_class.from_single_file(
model_path,
torch_dtype=torch.float16 if device != "cpu" else torch.float32,
)
if device != "auto":
pipe = pipe.to(device)
return pipe
else:
print(f"Warning: No single-file pipeline class found for architecture '{architecture}'")
if is_controlnet:
CONTROLNET_PIPELINE_MAP = {
"StableDiffusionPipeline": StableDiffusionControlNetPipeline,
Expand Down
Loading