# Source code for unimol_tools.weights.weighthub

import os

from ..utils import logger

# huggingface_hub is an optional dependency. When it cannot be imported, a
# stub with the same name is installed so that the failure surfaces (with
# installation / manual-download hints) only when a download is attempted.
try:
    from huggingface_hub import snapshot_download
except ImportError:
    # Only a failed import is handled here; KeyboardInterrupt, SystemExit and
    # unrelated errors are no longer swallowed as a bare ``except:`` would.
    huggingface_hub_installed = False

    def snapshot_download(*args, **kwargs):
        """
        Stand-in for ``huggingface_hub.snapshot_download``.

        Accepts any arguments and ignores them.

        :raises ImportError: always, explaining how to install huggingface_hub
            or where to download the weights manually.
        """
        raise ImportError(
            'huggingface_hub is not installed. If weights are not avaliable, please install it by running: pip install huggingface_hub. Otherwise, please download the weights manually from https://huggingface.co/dptech/Uni-Mol-Models'
        )
else:
    # Bind the flag on the success path too, so it is always defined.
    huggingface_hub_installed = True


# Directory used for the pretrained weights: the value of the
# UNIMOL_WEIGHT_DIR environment variable when it is set, otherwise the
# directory that contains this module.
WEIGHT_DIR = os.environ.get(
    'UNIMOL_WEIGHT_DIR', os.path.dirname(os.path.abspath(__file__))
)

# Use the hf-mirror.com mirror to download weights by default, but keep any
# HF_ENDPOINT the user has already exported instead of overwriting it.
# NOTE(review): huggingface_hub appears to read HF_ENDPOINT when it is first
# imported, and this statement runs after the import attempt above -- confirm
# that the mirror is actually picked up.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")


def log_weights_dir():
    """
    Logs the directory where the weights are stored.

    A warning is logged when ``UNIMOL_WEIGHT_DIR`` is present in the
    environment (custom directory); otherwise an info message reports the
    default directory. In both cases the path shown is ``WEIGHT_DIR``.
    """
    if 'UNIMOL_WEIGHT_DIR' in os.environ:
        logger.warning(
            f'Using custom weight directory from UNIMOL_WEIGHT_DIR: {WEIGHT_DIR}'
        )
    else:
        logger.info(f'Weights will be downloaded to default directory: {WEIGHT_DIR}')
def _download_from_repo(repo_id, pretrain, save_path, local_dir_use_symlinks):
    """
    Downloads ``pretrain`` from the Hugging Face repository ``repo_id``.

    Shared implementation behind :func:`weight_download` and
    :func:`weight_download_v2`, which differ only in the repository they use.

    :param repo_id: (str), The Hugging Face repository to download from.
    :param pretrain: (str), The name of the pretrained model to download.
    :param save_path: (str), The directory where the weights should be saved.
    :param local_dir_use_symlinks: (bool), Forwarded unchanged to
        ``snapshot_download``.
    """
    log_weights_dir()

    # Nothing to do when the requested weights are already on disk.
    if os.path.exists(os.path.join(save_path, pretrain)):
        logger.info(f'{pretrain} exists in {save_path}')
        return

    logger.info(f'Downloading {pretrain}')
    snapshot_download(
        repo_id=repo_id,
        local_dir=save_path,
        # Restrict the snapshot to the requested weight file.
        allow_patterns=pretrain,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )


def weight_download(pretrain, save_path, local_dir_use_symlinks=True):
    """
    Downloads the specified pretrained model weights from the
    ``dptech/Uni-Mol-Models`` repository.

    The download is skipped when ``pretrain`` already exists under
    ``save_path``.

    :param pretrain: (str), The name of the pretrained model to download.
    :param save_path: (str), The directory where the weights should be saved.
    :param local_dir_use_symlinks: (bool, optional), Whether to use symlinks for the local directory. Defaults to True.
    """
    _download_from_repo(
        "dptech/Uni-Mol-Models", pretrain, save_path, local_dir_use_symlinks
    )


def weight_download_v2(pretrain, save_path, local_dir_use_symlinks=True):
    """
    Downloads the specified pretrained model weights from the
    ``dptech/Uni-Mol2`` repository.

    The download is skipped when ``pretrain`` already exists under
    ``save_path``.

    :param pretrain: (str), The name of the pretrained model to download.
    :param save_path: (str), The directory where the weights should be saved.
    :param local_dir_use_symlinks: (bool, optional), Whether to use symlinks for the local directory. Defaults to True.
    """
    _download_from_repo(
        "dptech/Uni-Mol2", pretrain, save_path, local_dir_use_symlinks
    )
# Download all the weights when this script is run
def download_all_weights(local_dir_use_symlinks=False):
    """
    Downloads all pretrained model weights from the ``dptech/Uni-Mol-Models``
    repository to the WEIGHT_DIR.

    Only that repository is fetched; every file in it is requested
    (``allow_patterns='*'``).

    :param local_dir_use_symlinks: (bool, optional), Whether to use symlinks for the local directory. Defaults to False.
    """
    log_weights_dir()

    logger.info(f'Downloading all weights to {WEIGHT_DIR}')
    snapshot_download(
        repo_id="dptech/Uni-Mol-Models",
        local_dir=WEIGHT_DIR,
        allow_patterns='*',
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
# Script entry point: fetch the weights only when this file is executed
# directly, never on import.
if __name__ == '__main__':
    download_all_weights()