diff --git a/torch/__init__.py b/torch/__init__.py index c45e2d8f0de33ed..bd0bfa59d5919c2 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -142,6 +142,24 @@ kernel32.SetErrorMode(prev_error_mode) +def _preload_cuda_deps(): + """ Preloads cudnn/cublas deps if they could not be found otherwise """ + # Should only be called on Linux if default path resolution have failed + assert platform.system() == 'Linux', 'Should only be called on Linux' + for path in sys.path: + nvidia_path = os.path.join(path, 'nvidia') + if not os.path.exists(nvidia_path): + continue + cublas_path = os.path.join(nvidia_path, 'cublas', 'lib', 'libcublas.so.11') + cudnn_path = os.path.join(nvidia_path, 'cudnn', 'lib', 'libcudnn.so.8') + if not os.path.exists(cublas_path) or not os.path.exists(cudnn_path): + continue + break + + ctypes.CDLL(cublas_path) + ctypes.CDLL(cudnn_path) + + # See Note [Global dependencies] def _load_global_deps(): if platform.system() == 'Windows' or sys.executable == 'torch_deploy': @@ -151,7 +169,15 @@ def _load_global_deps(): here = os.path.abspath(__file__) lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name) - ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) + try: + ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) + except OSError as err: + # Can only happen of wheel with cublas as PYPI deps + # As PyTorch is not purelib, but nvidia-cublas-cu11 is + if 'libcublas.so.11' not in err.args[0]: + raise err + _preload_cuda_deps() + ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \