Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import functools
import os
import re
from dataclasses import dataclass
from typing import NoReturn, TypedDict

Expand Down Expand Up @@ -62,6 +63,7 @@ class _BitcodeLibInfo(TypedDict):
name for name, info in _SUPPORTED_BITCODE_LIBS_INFO.items() if not IS_WINDOWS or info["available_on_windows"]
)
)
_SM_ARCH_PATTERN = re.compile(r"sm[0-9]+[a-z]?")


def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str], attachments: list[str]) -> None:
Expand All @@ -74,13 +76,24 @@ def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str]
attachments.append(f' Directory does not exist: "{dir_path}"')


def _filename_with_sm_arch(filename: str, sm_arch: str | None) -> str:
if sm_arch is None:
return filename

if not _SM_ARCH_PATTERN.fullmatch(sm_arch):
raise ValueError(f"Invalid sm_arch: {sm_arch!r} must match {_SM_ARCH_PATTERN.pattern!r}")

stem, ext = os.path.splitext(filename)
return f"{stem}_{sm_arch}{ext}"


class _FindBitcodeLib:
def __init__(self, name: str) -> None:
def __init__(self, name: str, sm_arch: str | None = None) -> None:
if name not in _SUPPORTED_BITCODE_LIBS_INFO: # Updated reference
raise ValueError(f"Unknown bitcode library: '{name}'. Supported: {', '.join(SUPPORTED_BITCODE_LIBS)}")
self.name: str = name
self.config: _BitcodeLibInfo = _SUPPORTED_BITCODE_LIBS_INFO[name] # Updated reference
self.filename: str = self.config["filename"]
self.filename: str = _filename_with_sm_arch(self.config["filename"], sm_arch)
self.rel_path: str = self.config["rel_path"]
self.site_packages_dirs: tuple[str, ...] = self.config["site_packages_dirs"]
self.error_messages: list[str] = []
Expand Down Expand Up @@ -130,14 +143,25 @@ def raise_not_found_error(self) -> NoReturn:
raise BitcodeLibNotFoundError(f'Failure finding "{self.filename}": {err}\n{att}')


def locate_bitcode_lib(name: str) -> LocatedBitcodeLib:
def locate_bitcode_lib(name: str, *, sm_arch: str | None = None) -> LocatedBitcodeLib:
"""Locate a bitcode library by name.

When ``sm_arch`` is not ``None``, locate the architecture-specific bitcode
filename with ``_{sm_arch}`` inserted before the ``.bc`` suffix.

Args:
name: Name of the supported bitcode library to locate.
sm_arch: Optional SM architecture suffix, such as ``"sm90"`` or
``"sm90a"``. If not ``None``, it must match
``sm[0-9]+[a-z]?``.

Raises:
ValueError: If ``name`` is not a supported bitcode library.
ValueError: If ``name`` is not a supported bitcode library, or if
``sm_arch`` is not ``None`` and does not match
``sm[0-9]+[a-z]?``.
BitcodeLibNotFoundError: If the bitcode library cannot be found.
"""
finder = _FindBitcodeLib(name)
finder = _FindBitcodeLib(name, sm_arch)

abs_path = finder.try_site_packages()
if abs_path is not None:
Expand Down Expand Up @@ -170,11 +194,22 @@ def locate_bitcode_lib(name: str) -> LocatedBitcodeLib:


@functools.cache
def find_bitcode_lib(name: str) -> str:
def find_bitcode_lib(name: str, *, sm_arch: str | None = None) -> str:
"""Find the absolute path to a bitcode library.

When ``sm_arch`` is not ``None``, find the architecture-specific bitcode
filename with ``_{sm_arch}`` inserted before the ``.bc`` suffix.

Args:
name: Name of the supported bitcode library to find.
sm_arch: Optional SM architecture suffix, such as ``"sm90"`` or
``"sm90a"``. If not ``None``, it must match
``sm[0-9]+[a-z]?``.

Raises:
ValueError: If ``name`` is not a supported bitcode library.
ValueError: If ``name`` is not a supported bitcode library, or if
``sm_arch`` is not ``None`` and does not match
``sm[0-9]+[a-z]?``.
BitcodeLibNotFoundError: If the bitcode library cannot be found.
"""
return locate_bitcode_lib(name).abs_path
return locate_bitcode_lib(name, sm_arch=sm_arch).abs_path
159 changes: 149 additions & 10 deletions cuda_pathfinder/tests/test_find_bitcode_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ def _bitcode_lib_info(libname: str):
return find_bitcode_lib_module._SUPPORTED_BITCODE_LIBS_INFO[libname]


def _bitcode_lib_filename(libname: str) -> str:
return _bitcode_lib_info(libname)["filename"]


@pytest.fixture
def clear_find_bitcode_lib_cache():
find_bitcode_lib_module.find_bitcode_lib.cache_clear()
Expand All @@ -36,9 +32,9 @@ def clear_find_bitcode_lib_cache():
get_cuda_path_or_home.cache_clear()


def _make_bitcode_lib_file(dir_path: Path, libname: str) -> str:
def _make_bitcode_lib_file(dir_path: Path, filename: str) -> str:
dir_path.mkdir(parents=True, exist_ok=True)
file_path = dir_path / _bitcode_lib_filename(libname)
file_path = dir_path / filename
file_path.touch()
return str(file_path)

Expand Down Expand Up @@ -92,14 +88,16 @@ def test_locate_bitcode_lib(info_summary_append, libname):
@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
@pytest.mark.parametrize("libname", SUPPORTED_BITCODE_LIBS)
def test_locate_bitcode_lib_search_order(monkeypatch, tmp_path, libname):
filename = _bitcode_lib_info(libname)["filename"]

site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir, libname)
site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir, filename)

conda_prefix = tmp_path / "conda-prefix"
conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), libname)
conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), filename)

cuda_home = tmp_path / "cuda-home"
cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home, libname), libname)
cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home, libname), filename)

site_packages_sub_dirs = tuple(
tuple(rel_dir.split("/")) for rel_dir in _bitcode_lib_info(libname)["site_packages_dirs"]
Expand Down Expand Up @@ -135,6 +133,84 @@ def find_expected_sub_dir(sub_dir):
assert located_lib.found_via == "CUDA_PATH"


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
@pytest.mark.skipif("nvshmem_device" not in SUPPORTED_BITCODE_LIBS, reason="nvshmem_device is not supported")
def test_locate_bitcode_lib_with_sm_arch_search_order(monkeypatch, tmp_path):
libname = "nvshmem_device"
sm_arch = "sm90"
filename = "libnvshmem_device_sm90.bc"

site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir, filename)

conda_prefix = tmp_path / "conda-prefix"
conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), filename)

cuda_home = tmp_path / "cuda-home"
cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home, libname), filename)

site_packages_sub_dirs = tuple(
tuple(rel_dir.split("/")) for rel_dir in _bitcode_lib_info(libname)["site_packages_dirs"]
)

def find_expected_sub_dir(sub_dir):
assert sub_dir in site_packages_sub_dirs
if sub_dir == site_packages_sub_dirs[0]:
return [str(site_packages_lib_dir)]
return []

monkeypatch.setattr(
find_bitcode_lib_module,
"find_sub_dirs_all_sitepackages",
find_expected_sub_dir,
)
monkeypatch.setenv("CONDA_PREFIX", str(conda_prefix))
monkeypatch.setenv("CUDA_HOME", str(cuda_home))
monkeypatch.delenv("CUDA_PATH", raising=False)

located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
assert located_lib.abs_path == site_packages_path
assert located_lib.filename == filename
assert located_lib.found_via == "site-packages"
assert find_bitcode_lib(libname, sm_arch=sm_arch) == site_packages_path
os.remove(site_packages_path)
find_bitcode_lib_module.find_bitcode_lib.cache_clear()

located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
assert located_lib.abs_path == conda_path
assert located_lib.filename == filename
assert located_lib.found_via == "conda"
os.remove(conda_path)

located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
assert located_lib.abs_path == cuda_home_path
assert located_lib.filename == filename
assert located_lib.found_via == "CUDA_PATH"


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
@pytest.mark.skipif("nvshmem_device" not in SUPPORTED_BITCODE_LIBS, reason="nvshmem_device is not supported")
def test_find_bitcode_lib_cache_keeps_sm_arch_separate(monkeypatch, tmp_path):
libname = "nvshmem_device"
site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
sm80_path = _make_bitcode_lib_file(site_packages_lib_dir, "libnvshmem_device_sm80.bc")
sm90_path = _make_bitcode_lib_file(site_packages_lib_dir, "libnvshmem_device_sm90.bc")
sm90a_path = _make_bitcode_lib_file(site_packages_lib_dir, "libnvshmem_device_sm90a.bc")

monkeypatch.setattr(
find_bitcode_lib_module,
"find_sub_dirs_all_sitepackages",
lambda _sub_dir: [str(site_packages_lib_dir)],
)
monkeypatch.delenv("CONDA_PREFIX", raising=False)
monkeypatch.delenv("CUDA_HOME", raising=False)
monkeypatch.delenv("CUDA_PATH", raising=False)

assert find_bitcode_lib(libname, sm_arch="sm80") == sm80_path
assert find_bitcode_lib(libname, sm_arch="sm90") == sm90_path
assert find_bitcode_lib(libname, sm_arch="sm90a") == sm90a_path


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing(monkeypatch, tmp_path):
cuda_home = tmp_path / "cuda-home"
Expand All @@ -156,12 +232,44 @@ def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing(m
find_bitcode_lib("device")

message = str(exc_info.value)
expected_missing_file = os.path.join(str(lib_dir), _bitcode_lib_filename("device"))
expected_missing_file = os.path.join(str(lib_dir), _bitcode_lib_info("device")["filename"])
assert f"No such file: {expected_missing_file}" in message
assert f'listdir("{lib_dir}"):' in message
assert "README.txt" in message


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
@pytest.mark.skipif("nvshmem_device" not in SUPPORTED_BITCODE_LIBS, reason="nvshmem_device is not supported")
def test_find_bitcode_lib_with_sm_arch_not_found_error_uses_arch_specific_filename(monkeypatch, tmp_path):
libname = "nvshmem_device"
sm_arch = "sm90"
expected_filename = "libnvshmem_device_sm90.bc"

cuda_home = tmp_path / "cuda-home"
lib_dir = _bitcode_lib_dir_under(cuda_home, libname)
lib_dir.mkdir(parents=True, exist_ok=True)
extra_file = lib_dir / "libnvshmem_device.bc"
extra_file.touch()

monkeypatch.setattr(
find_bitcode_lib_module,
"find_sub_dirs_all_sitepackages",
lambda _sub_dir: [],
)
monkeypatch.delenv("CONDA_PREFIX", raising=False)
monkeypatch.setenv("CUDA_HOME", str(cuda_home))
monkeypatch.delenv("CUDA_PATH", raising=False)

with pytest.raises(BitcodeLibNotFoundError, match=rf'Failure finding "{expected_filename}"') as exc_info:
find_bitcode_lib(libname, sm_arch=sm_arch)

message = str(exc_info.value)
expected_missing_file = os.path.join(str(lib_dir), expected_filename)
assert f"No such file: {expected_missing_file}" in message
assert f'listdir("{lib_dir}"):' in message
assert "libnvshmem_device.bc" in message


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
def test_find_bitcode_lib_not_found_error_without_cuda_home(monkeypatch):
monkeypatch.setattr(
Expand All @@ -183,3 +291,34 @@ def test_find_bitcode_lib_not_found_error_without_cuda_home(monkeypatch):
def test_find_bitcode_lib_invalid_name():
with pytest.raises(ValueError, match="Unknown bitcode library"):
find_bitcode_lib_module.locate_bitcode_lib("invalid")


@pytest.mark.parametrize(
"find_fn",
[
find_bitcode_lib,
locate_bitcode_lib,
],
)
def test_bitcode_lib_sm_arch_is_keyword_only(find_fn):
with pytest.raises(TypeError):
find_fn("nvshmem_device", "sm90")


@pytest.mark.parametrize(
"sm_arch",
[
"",
"../sm90",
"compute90",
"sm_90",
"sm",
"sm90/extra",
"sm90A",
],
)
def test_find_bitcode_lib_invalid_sm_arch(sm_arch):
expected_pattern = find_bitcode_lib_module._SM_ARCH_PATTERN.pattern
with pytest.raises(ValueError) as exc_info:
find_bitcode_lib_module.locate_bitcode_lib("device", sm_arch=sm_arch)
assert f"must match {expected_pattern!r}" in str(exc_info.value)
Loading