Skip to content

Commit

Permalink
patch hf hub (modelscope#987)
Browse files Browse the repository at this point in the history
  • Loading branch information
tastelikefeet committed Sep 14, 2024
1 parent 51b33ce commit 4c518db
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 3 deletions.
20 changes: 20 additions & 0 deletions modelscope/hub/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,26 @@ def get_model_files(self,
files.append(file)
return files

def file_exists(
    self,
    repo_id: str,
    filename: str,
    *,
    revision: Optional[str] = None,
):
    """Check whether a file exists in the given repository.

    Args:
        repo_id (`str`): The repo id to use
        filename (`str`): The queried filename
        revision (`Optional[str]`): The repo revision

    Returns:
        bool: `True` when the file is present in the repo, `False` otherwise.
    """
    repo_files = self.get_model_files(repo_id, revision=revision)
    return any(entry['Name'] == filename for entry in repo_files)

def create_dataset(self,
dataset_name: str,
namespace: str,
Expand Down
161 changes: 160 additions & 1 deletion modelscope/utils/hf_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib
import os
from pathlib import Path
from types import MethodType
from typing import Dict, Literal, Optional, Union

from transformers import AutoConfig as AutoConfigHF
from transformers import AutoImageProcessor as AutoImageProcessorHF
Expand All @@ -14,10 +18,12 @@
from transformers import BatchFeature as BatchFeatureHF
from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
from transformers import GenerationConfig as GenerationConfigHF
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers import (PretrainedConfig, PreTrainedModel,
PreTrainedTokenizerBase)

from modelscope import snapshot_download
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
from .logger import get_logger

try:
from transformers import GPTQConfig as GPTQConfigHF
Expand All @@ -26,6 +32,8 @@
GPTQConfigHF = None
AwqConfigHF = None

logger = get_logger()


def user_agent(invoked_by=None):
if invoked_by is None:
Expand All @@ -34,6 +42,157 @@ def user_agent(invoked_by=None):
return uagent


def _try_login(token: Optional[str] = None):
    """Log in to the ModelScope hub when a token is available.

    Falls back to the ``MODELSCOPE_API_TOKEN`` environment variable when no
    token is passed explicitly; does nothing when neither is set.
    """
    from modelscope.hub.api import HubApi
    hub_api = HubApi()
    resolved_token = token if token is not None else os.environ.get(
        'MODELSCOPE_API_TOKEN')
    if resolved_token:
        hub_api.login(resolved_token)


def _file_exists(
    self,
    repo_id: str,
    filename: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    token: Union[str, bool, None] = None,
):
    """Patch of ``huggingface_hub.file_exists`` backed by the ModelScope hub.

    Args:
        repo_id (`str`): The repo id to query.
        filename (`str`): The file name to look for.
        repo_type (`Optional[str]`): Only model repos are supported; any
            other value is ignored with a warning.
        revision (`Optional[str]`): The repo revision to query.
        token (`Union[str, bool, None]`): Optional ModelScope token used to
            log in before querying.

    Returns:
        The query result in bool value.
    """
    # Fix: only warn for genuinely unsupported repo types. The original
    # warned whenever repo_type was not None, including repo_type='model',
    # which is exactly the (only) supported type.
    if repo_type not in (None, 'model'):
        logger.warning(
            'The passed in repo_type will not be used in modelscope. Now only model repo can be queried.'
        )
    _try_login(token)
    from modelscope.hub.api import HubApi
    api = HubApi()
    return api.file_exists(repo_id, filename, revision=revision)


def _file_download(repo_id: str,
filename: str,
*,
subfolder: Optional[str] = None,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
local_dir: Union[str, Path, None] = None,
token: Union[bool, str, None] = None,
local_files_only: bool = False,
**kwargs):
"""Patch huggingface_hub.hf_hub_download"""
if len(kwargs) > 0:
logger.warning(
'The passed in library_name,library_version,user_agent,force_download,proxies'
'etag_timeout,headers,endpoint '
'will not be used in modelscope.')
assert repo_type in (
None, 'model',
'dataset'), f'repo_type={repo_type} is not supported in ModelScope'
if repo_type in (None, 'model'):
from modelscope.hub.file_download import model_file_download as file_download
else:
from modelscope.hub.file_download import dataset_file_download as file_download
_try_login(token)
return file_download(
repo_id,
file_path=os.path.join(subfolder, filename) if subfolder else filename,
cache_dir=cache_dir,
local_dir=local_dir,
local_files_only=local_files_only,
revision=revision)


def _patch_pretrained_class():
    """Monkey patch the transformers ``from_pretrained`` entry points so model
    ids are resolved through the ModelScope hub (snapshot-downloading the repo
    first) instead of the HuggingFace hub.
    """

    def get_model_dir(pretrained_model_name_or_path, ignore_file_pattern,
                      **kwargs):
        # Resolve a model id to a local directory, downloading from the
        # ModelScope hub when the path does not already exist locally.
        # NOTE: pops 'revision' from kwargs so it is consumed here and not
        # forwarded to the original transformers implementation.
        if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern)
        else:
            model_dir = pretrained_model_name_or_path
        return model_dir

    def patch_tokenizer_base():
        """ Monkey patch PreTrainedTokenizerBase.from_pretrained to adapt to modelscope hub.
        """
        ori_from_pretrained = PreTrainedTokenizerBase.from_pretrained.__func__

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                            **kwargs):
            # Tokenizer loading does not need the weight files.
            ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
            model_dir = get_model_dir(pretrained_model_name_or_path,
                                      ignore_file_pattern, **kwargs)
            return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)

        PreTrainedTokenizerBase.from_pretrained = from_pretrained

    def patch_config_base():
        """ Monkey patch PretrainedConfig.from_pretrained to adapt to modelscope hub.
        """
        ori_from_pretrained = PretrainedConfig.from_pretrained.__func__
        ori_get_config_dict = PretrainedConfig.get_config_dict.__func__

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                            **kwargs):
            # Config loading does not need the weight files.
            ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
            model_dir = get_model_dir(pretrained_model_name_or_path,
                                      ignore_file_pattern, **kwargs)
            return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)

        @classmethod
        def get_config_dict(cls, pretrained_model_name_or_path, **kwargs):
            ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
            model_dir = get_model_dir(pretrained_model_name_or_path,
                                      ignore_file_pattern, **kwargs)
            return ori_get_config_dict(cls, model_dir, **kwargs)

        # Fix: the original defined the patched ``from_pretrained`` (and
        # captured ``ori_from_pretrained``) but never installed it, leaving
        # PretrainedConfig.from_pretrained unpatched and the definition as
        # dead code — contradicting this function's own docstring.
        PretrainedConfig.from_pretrained = from_pretrained
        PretrainedConfig.get_config_dict = get_config_dict

    def patch_model_base():
        """ Monkey patch PreTrainedModel.from_pretrained to adapt to modelscope hub.
        """
        ori_from_pretrained = PreTrainedModel.from_pretrained.__func__

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                            **kwargs):
            # Model loading needs the weights, so nothing is ignored here.
            model_dir = get_model_dir(pretrained_model_name_or_path, None,
                                      **kwargs)
            return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)

        PreTrainedModel.from_pretrained = from_pretrained

    patch_tokenizer_base()
    patch_config_base()
    patch_model_base()


def patch_hub():
    """Patch the huggingface hub so that downloads are served by the
    ModelScope hub instead, which can speed up model fetching for users.
    """
    import huggingface_hub
    from huggingface_hub import hf_api
    from huggingface_hub.hf_api import api

    # Redirect file downloads to the ModelScope implementation.
    huggingface_hub.hf_hub_download = _file_download
    huggingface_hub.file_download.hf_hub_download = _file_download

    # Bind the replacement once to the shared HfApi instance, then install
    # it at every place huggingface_hub re-exports file_exists.
    bound_file_exists = MethodType(_file_exists, api)
    hf_api.file_exists = bound_file_exists
    huggingface_hub.file_exists = bound_file_exists
    huggingface_hub.hf_api.file_exists = bound_file_exists

    _patch_pretrained_class()


def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
"""Get a custom wrapper class for auto classes to download the models from the ModelScope hub
Args:
Expand Down
2 changes: 1 addition & 1 deletion requirements/datasets.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
addict
attrs
datasets>=2.18.0
datasets>=2.18.0,<3.0.0
einops
oss2
Pillow
Expand Down
2 changes: 1 addition & 1 deletion requirements/framework.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
addict
attrs
datasets>=2.18.0
datasets>=2.18.0,<3.0.0
einops
oss2
Pillow
Expand Down

0 comments on commit 4c518db

Please sign in to comment.