Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Connor Baker committed May 18, 2024
1 parent 0011e36 commit ee0b90c
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 244 deletions.
231 changes: 46 additions & 185 deletions pkgs/development/cuda-modules/package-sets.nix
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,20 @@
}:
let
inherit (config) data utils;
inherit (pkgs)
fetchurl
newScope
srcOnly
stdenv
;
inherit (lib.attrsets)
attrByPath
dontRecurseIntoAttrs
filterAttrs
mapAttrs
optionalAttrs
;
inherit (pkgs) newScope stdenv;
inherit (lib.attrsets) dontRecurseIntoAttrs mapAttrs optionalAttrs;
inherit (lib.customisation) makeScope;
inherit (lib.filesystem) packagesFromDirectoryRecursive;
inherit (lib.licenses) nvidiaCudaRedist;
inherit (lib.lists)
foldl'
map
naturalSort
optionals
unique
;
inherit (lib.meta) addMetaAttrs;
inherit (lib.options) mkOption;
inherit (lib.strings) removeSuffix versionAtLeast versionOlder;
inherit (lib.trivial) const flip pipe;
inherit (lib.strings) versionAtLeast versionOlder;
inherit (lib.trivial) const pipe;
inherit (lib.types) attrsOf raw;
inherit (lib.versions) major majorMinor;

Expand All @@ -41,151 +28,12 @@ let
# - Overrides for cutensor, etc.
# - Rename the platform type to redistPlatform to distinguish it from the Nix platform.

redistArch = utils.getRedistArch stdenv.hostPlatform.system;
isSupportedRedistArch = redistArch != "unsupported";

# Function to override a package.
overridePackage =
final:
let
# Maps redist name to package name to override.
overrides = utils.getOverrides final;
in
{ packageName, redistName, ... }:
package: package.overrideAttrs (overrides.${redistName}.${packageName} or { });

getNixPlatforms =
platform:
if platform == "source" then
# All platforms are supported for source packages.
[
"aarch64-linux" # Both SBSA (ARM servers) and Jetson
"x86_64-linux" # x86_64
"ppc64le-linux" # POWER
]
else
[ (utils.getNixPlatform platform) ];

# Function to build a redistributable package.
buildRedistPackage =
final: cudaMajorMinorPatchVersion:
meta@{
packageName,
redistName,
releaseInfo,
packageInfo,
version,
...
}:
let
package = pipe ./manifest-builder.nix [
# Build the package.
(flip final.callPackage {
inherit (meta) packageInfo packageName releaseInfo;
libPath = utils.getLibPath cudaMajorMinorPatchVersion packageInfo.feature.cudaVersionsInLib;
# The source is given by the tarball, which we unpack and use as a FOD.
src =
let
tarball = fetchurl {
inherit (packageInfo) sha256;
url =
if redistName == "tensorrt" then
utils.mkTensorRTURL version packageInfo.relativePath
else
utils.mkRedistURL redistName (utils.mkRelativePath meta);
};
unpacked = srcOnly {
__structuredAttrs = true;
strictDeps = true;
name = tarball.name + "-unpacked";
src = tarball;
outputHashMode = "recursive";
outputHash = packageInfo.narHash;
};
in
unpacked;
})
# Update the package license
(addMetaAttrs {
license = nvidiaCudaRedist // {
url =
let
licensePath =
if releaseInfo.licensePath != null then releaseInfo.licensePath else "${packageName}/LICENSE.txt";
in
"https://developer.download.nvidia.com/compute/${redistName}/redist/${licensePath}";
};
})
# Apply package-specific overrides, if they exist.
(overridePackage final meta)
];
in
package;

# Function to determine if a package should be accepted into the package set.
packageCudaVariantMatches =
cudaMajorMinorPatchVersion:
{
cudaVariant,
packageInfo,
platform,
redistName,
version,
...
}:
let
# One of the subdirectories of the lib directory contains a supported version for our version of CUDA.
# This is typically found with older versions of redistributables which don't use separate tarballs for each
# supported CUDA version.
hasSupportedCudaVersionInLib =
(utils.getLibPath cudaMajorMinorPatchVersion packageInfo.feature.cudaVersionsInLib) != null;
# There is a variant for the desired CUDA version.
isDesiredCudaVariant = cudaVariant == (utils.mkCudaVariant cudaMajorMinorPatchVersion);
in
attrByPath [ redistName ] (isDesiredCudaVariant || hasSupportedCudaVersionInLib) {
# CUBLASMP: Looks like it requires at least 11.8:
# https://docs.nvidia.com/cuda/cublasmp/getting_started/index.html
cublasmp = versionAtLeast cudaMajorMinorPatchVersion "11.8";

# CUDA: None of the CUDA redistributables have CUDA variants, but we only need to check that the release
# version matches the CUDA version we want.
cuda = version == cudaMajorMinorPatchVersion;

# CUDNN: Since cuDNN 8.5, it is possible to use the dynamic library for a CUDA release with any CUDA version
# in that major release series. For example, the cuDNN 8.5 dynamic library for CUDA 11.0 can be used with
# any CUDA 11.x release. (This functionality is not present for the CUDA 10.2 releases.)
# As such, it is enough that the cuda variant matches to accept the package.
cudnn = isDesiredCudaVariant;

# CUQUANTUM: Only available for CUDA 11.5 and later.
cuquantum =
if cudaVariant == "None" then
# This handles the case of pre-23.03 releases, which don't provide CUDA versions in lib
versionAtLeast cudaMajorMinorPatchVersion "11.5"
# And this handles the case of pre-23.06 releases, which do
|| hasSupportedCudaVersionInLib
else
isDesiredCudaVariant;

# CUTENSOR: Instead of providing CUDA variants, cuTensor provides multiple versions of the library nested
# in the lib directory. So long as one of the versions in cudaVersionsInLib is a prefix of the current CUDA
# version, we accept the package. We should have a more stringent version check, but no one has written
# a sidecar file mapping releases to supported CUDA versions.
cutensor = hasSupportedCudaVersionInLib;

# TODO: Add constraints for TensorRT

# TODO: These constraints are duplicated in the overrides files because the overrides package set is created
# with callPackage and evaluated strictly. As such, the constraints should exist outside of either of these
# places -- perhaps in config.data?
};

# Function to determine if a package should be accepted into the package set.
packagePlatformMatches =
platform: platform == "source" || (isSupportedRedistArch && platform == redistArch);

# Function to update the package set with the new package.
updatePackages =
let
redistArch = utils.getRedistArch stdenv.hostPlatform.system;
isSupportedRedistArch = redistArch != "unsupported";
in
{
packageInfo,
packageName,
Expand All @@ -194,7 +42,6 @@ let
releaseInfo,
...
}:
packages: package:
let
# Non-CUDA redist packages, like cuDNN, are multiplexed so we add a suffix to the attribute name.
# CUDA redist packages are always the same as the package name.
Expand All @@ -204,34 +51,46 @@ let
else
packageName;

# We want to make sure that packages exist even on platforms they cannot be built on or for, to avoid missing
# attribute errors.
#
# For each package we see, we do the following:
#
# - If a package of the same name does not exist in the package set:
# - If the package we see is for the current platform:
# - Add the package to the package set.
# - If the package we see is not for the current platform:
# - Add the package to the package set, but with a null `src` attribute and a `meta.platforms` attribute
# equal to the platform the package is for.
# - If a package of the same name exists in the package set:
# - If the package we see is for the current platform:
# - Override the `src` attribute and append the current platform to the `meta.platforms` attribute.
# - If the package we see is not for the current platform:
# - Append the current platform to the `meta.platforms` attribute.
# Package platform matches the current platform or the package is a source package.
packagePlatformMatches = platform == "source" || (isSupportedRedistArch && platform == redistArch);
in
packages: package:
let
packageForName =
name:
let
# # We want to make sure that packages exist even on platforms they cannot be built on or for, to avoid missing
# # attribute errors.
# #
# # For each package we see, we do the following:
# #
# # - If a package of the same name exists in the package set: (packageExists)
# packageExists = packages ? ${name};
# # - If the package we see is for the current platform: (packagePlatformMatches)
# # - Override the `src` attribute and append the current platform to the `meta.platforms` attribute.
# setSrcAndAppendPlatformToMetaPlatforms = packageExists && packagePlatformMatches;
# # - If the package we see is not for the current platform: (!packagePlatformMatches)
# # - Append the current platform to the `meta.platforms` attribute. (appendPlatformToMetaPlatforms)
# appendPlatformToMetaPlatforms = packageExists && !packagePlatformMatches;
# # - If a package of the same name does not exist in the package set: (!packageExists)
# # - If the package we see is for the current platform: (packagePlatformMatches)
# # - Add the package to the package set. (addPackageToPackageSet)
# addPackageToPackageSet = !packageExists && packagePlatformMatches;
# # - If the package we see is not for the current platform: (!packagePlatformMatches)
# # - Add the package to the package set, but with a null `src` attribute and a `meta.platforms` attribute
# # equal to the platform the package is for. (setSrcToNullAndMetaPlatformsToPlatform)
# setSrcToNullAndMetaPlatformsToPlatform = noPackageExists && !packagePlatformMatches;

# # A package exists in the package set, but it's older than the package we're processing.
# existingPackageIsOlder = packages ? ${name} && versionOlder packages.${name}.version package.version;

replacePackage =
# There's no entry for the package in the package set.
!(packages ? ${name})
# There is an entry but it's older than the package we're processing.
|| (versionOlder packages.${name}.version package.version);

package' = if replacePackage then package else packages.${name};

platformMatched = packagePlatformMatches platform;
in
pipe package' [
# Update the package.
Expand All @@ -240,16 +99,16 @@ let
# Only do the override if the platform matched or we're replacing the package.
# Otherwise leave it alone.
pkg.override (
optionalAttrs (platformMatched || replacePackage) {
optionalAttrs (packagePlatformMatches || replacePackage) {
src =
if platformMatched then
if packagePlatformMatches then
package.src
else if replacePackage then
null
else
builtins.throw "This should never happen.";
packageInfo =
if platformMatched then
if packagePlatformMatches then
packageInfo
else if replacePackage then
{ feature.outputs = [ "out" ]; }
Expand All @@ -265,7 +124,7 @@ let
meta.platforms =
let
existingPlatforms = optionals (!replacePackage) (prevAttrs.meta.platforms or [ ]);
newPlatforms = getNixPlatforms platform;
newPlatforms = utils.getNixPlatforms platform;
in
# TODO: This has at least O(n^2) complexity due to the `unique` call.
# And we do this for every package we process...
Expand Down Expand Up @@ -304,15 +163,17 @@ let
...
}:
let
package = buildRedistPackage final cudaMajorMinorPatchVersion meta;
package = utils.buildRedistPackage final meta;
in
if packageCudaVariantMatches cudaMajorMinorPatchVersion meta then
if utils.packageSupportsCudaVersion cudaMajorMinorPatchVersion meta then
updatePackages meta packages package
else
packages;
in
{
lib = dontRecurseIntoAttrs lib;
# TODO: pkgs.cudaPackages_11_8.pkgs.cudaPackages.cudaVersion should be 11.8, not pkgs.cudaPackages.cudaVersion.
# pkgs = dontRecurseIntoAttrs pkgs // { inherit (final) cudaPackages; };
pkgs = dontRecurseIntoAttrs pkgs;
data = dontRecurseIntoAttrs config.data;
utils = dontRecurseIntoAttrs config.utils;
Expand Down
Loading

0 comments on commit ee0b90c

Please sign in to comment.