diff --git a/examples/PyTorch/Inference/hf_diffusers/README.md b/examples/PyTorch/Inference/hf_diffusers/README.md
new file mode 100644
index 00000000000..f0cc5d73d53
--- /dev/null
+++ b/examples/PyTorch/Inference/hf_diffusers/README.md
@@ -0,0 +1,32 @@
+# Accelerate Inference of Stable Diffusion using BladeDISC
+
+*(under development)*
+BladeDISC can compile the PyTorch models in a Stable Diffusion pipeline to improve inference speed.
+The general workflow is: export a model, call BladeDISC to compile it, then wrap the optimized model
+back into the original pipeline.
+To further simplify the optimization workflow, we provide an adapter for the Hugging Face Diffusers library.
+
+## Usage
+
+### Use Pipeline Adapter
+
+```python
+from blade_adapter import BladeStableDiffusionPipeline
+
+# use the adapter to load the pipeline and optimize its models:
+pipe = BladeStableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')
+
+# use the optimized pipeline like the original one:
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+
+# save and load the optimized pipeline (to avoid rerunning compilation from the original models every time):
+pipe.save_pretrained('cached/dir/stable-diffusion-v1-5-blade-opt')
+pipe = BladeStableDiffusionPipeline.from_pretrained('cached/dir/stable-diffusion-v1-5-blade-opt')
+
+```
+
+
+### Use Model Adapter
+
+*(TBD)*
\ No newline at end of file
diff --git a/examples/PyTorch/Inference/hf_diffusers/blade_adapter.py b/examples/PyTorch/Inference/hf_diffusers/blade_adapter.py
new file mode 100644
index 00000000000..655ebe15c32
--- /dev/null
+++ b/examples/PyTorch/Inference/hf_diffusers/blade_adapter.py
@@ -0,0 +1,137 @@
+# Copyright 2023 The BladeDISC Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+from os import PathLike
+from tempfile import TemporaryDirectory
+from typing import Union
+
+import torch
+from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+from diffusers.pipelines.pipeline_utils import LOADABLE_CLASSES
+from transformers import CLIPTextModel, PreTrainedModel
+
+LOGGER = logging.getLogger(__name__)
+
+
+class OptModel:
+    """Wrap a TorchScript module so it can be saved and loaded like a model.
+
+    Subclasses set ``original_class`` to the model class they stand in for
+    and implement ``gen_example_input`` to supply inputs for tracing.
+    """
+
+    # Class of the original model; ``None`` on this base class.
+    original_class = None
+    # File name used for the serialized TorchScript module.
+    model_file = 'model.jit'
+
+    def __init__(self, opt_model: torch.jit.ScriptModule):
+        self.opt_model = opt_model
+
+    def save_pretrained(self, save_directory: Union[str, PathLike], **kwargs):
+        """Save the wrapped module as ``model_file`` inside ``save_directory``.
+
+        The directory is created when it does not exist yet. Extra keyword
+        arguments are accepted but not used.
+        """
+        # torch.jit.save does not create the parent directory itself.
+        os.makedirs(save_directory, exist_ok=True)
+        torch.jit.save(self.opt_model, os.path.join(
+            save_directory, self.model_file))
+
+    @classmethod
+    def gen_example_input(cls):
+        """Return example inputs used for tracing; subclasses must override."""
+        raise NotImplementedError()
+
+    @classmethod
+    def from_pretrained(cls, cached_dir: Union[str, PathLike], **kwargs):
+        """Load from ``cached_dir``, preferring an already saved ``model_file``.
+
+        When no saved TorchScript file is present, the original model is
+        loaded and traced instead, and ``kwargs`` are forwarded to it.
+        """
+        if os.path.isfile(os.path.join(cached_dir, cls.model_file)):
+            return cls.from_opt(cached_dir)
+        return cls.from_original(cached_dir, **kwargs)
+
+    @classmethod
+    def from_opt(cls, cached_dir: Union[str, PathLike]):
+        """Load a previously saved TorchScript module from ``cached_dir``."""
+        opt_model = torch.jit.load(os.path.join(cached_dir, cls.model_file))
+        return cls(opt_model)
+
+    @classmethod
+    def from_original(cls, cached_dir: Union[str, PathLike], **kwargs):
+        """Load the original model from ``cached_dir`` and trace it.
+
+        ``kwargs`` are forwarded to ``original_class.from_pretrained``.
+        """
+        if issubclass(cls.original_class, PreTrainedModel):
+            # transformers models are loaded with their torchscript flag set.
+            kwargs['torchscript'] = True
+        # TODO(litan.ls): use load method from LOADABLE_CLASSES
+        original = cls.original_class.from_pretrained(cached_dir, **kwargs)
+
+        example_inputs = cls.gen_example_input()
+        traced = torch.jit.trace(original.eval(), example_inputs)
+        # TODO(litan.ls): call blade optimize
+        return cls(traced)
+
+
+class BladeCLIPTextModel(OptModel):
+    """Adapter standing in for ``transformers.CLIPTextModel``."""
+
+    original_class = CLIPTextModel
+
+    @classmethod
+    def gen_example_input(cls):
+        """Return a random int64 tensor of shape (1, 10) with values in [1, 999)."""
+        return torch.randint(1, 999, (1, 10), dtype=torch.int64)
+
+    def __call__(self, *args):
+        """Run the wrapped module on ``args`` and return its raw output."""
+        # TODO(litan.ls): wrapper output as original model
+        return self.opt_model(*args)
+
+
+class BladeUNet2DConditionModel(OptModel):
+    """Adapter standing in for ``diffusers.UNet2DConditionModel``."""
+
+    original_class = UNet2DConditionModel
+
+    @classmethod
+    def gen_example_input(cls):
+        """Return example (half tensor, int64 scalar, half tensor) inputs for tracing."""
+        # TODO(litan.ls): support gen input from pipeline config
+        return (
+            torch.randn((1, 4, 64, 64), dtype=torch.half),
+            torch.tensor(2, dtype=torch.int64),
+            torch.randn((1, 10, 768), dtype=torch.half),
+        )
+
+    def forward(self, *args):
+        """Run the wrapped module and wrap its result in ``UNet2DConditionOutput``."""
+        return UNet2DConditionOutput(self.opt_model(*args))
+
+    # OptModel is a plain class, not an nn.Module, so ``forward`` is never
+    # dispatched to implicitly; expose it through ``__call__`` so instances
+    # are callable in the same way as BladeCLIPTextModel.
+    __call__ = forward
+
+
+# Maps a pipeline component name to (original class entry, adapter class entry),
+# each written as [library/module name, class name].
+# TODO(litan.ls): support more models
+_MODEL_MAPPING = {
+    'text_encoder': (['transformers', 'CLIPTextModel'], ['blade_adapter', 'BladeCLIPTextModel']),
+    'unet': (['diffusers', 'UNet2DConditionModel'], ['blade_adapter', 'BladeUNet2DConditionModel']),
+}
+
+
+class BladeStableDiffusionPipeline(StableDiffusionPipeline):
+    """Stable Diffusion pipeline that loads the adapter classes in ``_MODEL_MAPPING``."""
+
+    @classmethod
+    def overwrite_config(cls, input_cached_dir: Union[str, PathLike], output_cached_dir: Union[str, PathLike]):
+        """Mirror ``input_cached_dir`` into ``output_cached_dir`` with a patched config.
+
+        Every file of the input directory is symlinked into the output
+        directory. The pipeline config is then replaced by a rewritten copy
+        in which the components listed in ``_MODEL_MAPPING`` point to their
+        adapter classes. Components that are missing, or whose type is not
+        the expected original one, are left unchanged and reported with a
+        warning.
+        """
+        config_dict = cls.load_config(input_cached_dir)
+        for k, (src_model, dst_model) in _MODEL_MAPPING.items():
+            if k not in config_dict:
+                LOGGER.warning(f'{k} model not found in pipeline config.')
+            elif config_dict[k] == dst_model:
+                # Already points to the adapter class; nothing to rewrite.
+                continue
+            elif config_dict[k] != src_model:
+                LOGGER.warning(
+                    f'Cannot overwrite {k} model type {config_dict[k]}, '
+                    f'expected {src_model}')
+            else:
+                config_dict[k] = dst_model
+        for dirpath, _, filenames in os.walk(input_cached_dir):
+            relpath = os.path.relpath(dirpath, input_cached_dir)
+            os.makedirs(os.path.join(
+                output_cached_dir, relpath), exist_ok=True)
+            for f in filenames:
+                os.symlink(os.path.abspath(os.path.join(dirpath, f)),
+                           os.path.join(output_cached_dir, relpath, f))
+        config_path = os.path.join(output_cached_dir, cls.config_name)
+        # Drop the symlink first so the write below creates a new file
+        # instead of following the link and modifying the input config.
+        os.unlink(config_path)
+        with open(config_path, 'w') as config_file:
+            json.dump(config_dict, config_file, indent=2)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, PathLike], **kwargs):
+        """Load a pipeline from a local directory using the adapter classes.
+
+        Args:
+            pretrained_model_name_or_path: Path of a local pipeline directory.
+            **kwargs: Forwarded unchanged to the parent ``from_pretrained``.
+
+        Raises:
+            NotImplementedError: If the argument is not an existing directory.
+        """
+        if not os.path.isdir(pretrained_model_name_or_path):
+            raise NotImplementedError('Support snapshot download')
+        cached_dir = pretrained_model_name_or_path
+
+        with TemporaryDirectory() as tmpdir:
+            cls.overwrite_config(cached_dir, tmpdir)
+            # Register the adapter module's save/load method names so the
+            # parent loader can handle OptModel subclasses.
+            LOADABLE_CLASSES['blade_adapter'] = {
+                "OptModel": ["save_pretrained", "from_pretrained"]}
+            return super().from_pretrained(tmpdir, **kwargs)
diff --git a/examples/PyTorch/Inference/hf_diffusers/tests/test_model.py b/examples/PyTorch/Inference/hf_diffusers/tests/test_model.py
new file mode 100644
index 00000000000..3d922065072
--- /dev/null
+++ b/examples/PyTorch/Inference/hf_diffusers/tests/test_model.py
@@ -0,0 +1,45 @@
+# Copyright 2023 The BladeDISC Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from tempfile import TemporaryDirectory
+
+import torch
+from blade_adapter import BladeCLIPTextModel
+from transformers import CLIPTextModel
+
+CACHED_DIR = 'model_cache/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819'
+
+
+class ModelTest(unittest.TestCase):
+    """Checks for the optimized model wrappers."""
+
+    def test_text_encoder(self):
+        """Traced text encoder matches the original, also after save and reload."""
+        encoder_dir = os.path.join(CACHED_DIR, 'text_encoder')
+        traced_encoder = BladeCLIPTextModel.from_original(encoder_dir)
+        reference_encoder = CLIPTextModel.from_pretrained(encoder_dir)
+        token_ids = BladeCLIPTextModel.gen_example_input()
+        traced_first = traced_encoder(token_ids)[0]
+        reference_first = reference_encoder(token_ids)[0]
+        torch.testing.assert_close(traced_first, reference_first)
+
+        with TemporaryDirectory() as save_dir:
+            traced_encoder.save_pretrained(save_dir)
+            saved_file = os.path.join(save_dir, 'model.jit')
+            self.assertTrue(os.path.isfile(saved_file))
+
+            # The reloaded module must reproduce the reference output too.
+            reloaded_encoder = BladeCLIPTextModel.from_opt(save_dir)
+            reloaded_first = reloaded_encoder(token_ids)[0]
+            torch.testing.assert_close(reloaded_first, reference_first)
+
+    # TODO(litan.ls): other model test
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/examples/PyTorch/Inference/hf_diffusers/tests/test_pipeline.py b/examples/PyTorch/Inference/hf_diffusers/tests/test_pipeline.py
new file mode 100644
index 00000000000..faab6877331
--- /dev/null
+++ b/examples/PyTorch/Inference/hf_diffusers/tests/test_pipeline.py
@@ -0,0 +1,37 @@
+# Copyright 2023 The BladeDISC Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from tempfile import TemporaryDirectory
+
+from blade_adapter import BladeStableDiffusionPipeline
+
+CACHED_DIR = 'model_cache/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819'
+PIPE_ID = 'runwayml/stable-diffusion-v1-5'
+
+
+class PipelineTest(unittest.TestCase):
+    """Checks for BladeStableDiffusionPipeline."""
+
+    def test_overwrite_config(self):
+        """The rewritten config names the adapter class for the text encoder."""
+        expected_entry = ['blade_adapter', 'BladeCLIPTextModel']
+        with TemporaryDirectory() as patched_dir:
+            BladeStableDiffusionPipeline.overwrite_config(CACHED_DIR, patched_dir)
+            patched_config = BladeStableDiffusionPipeline.load_config(patched_dir)
+            self.assertEqual(patched_config['text_encoder'], expected_entry)
+
+    def test_from_pretrained(self):
+        """Loading by PIPE_ID raises NotImplementedError; CACHED_DIR is then loaded."""
+        with self.assertRaises(NotImplementedError):
+            BladeStableDiffusionPipeline.from_pretrained(PIPE_ID)
+        BladeStableDiffusionPipeline.from_pretrained(CACHED_DIR)
+        # TODO(litan.ls): compare pipeline output
+
+
+if __name__ == '__main__':
+    unittest.main()