Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Msd gpu supports #299

Open
wants to merge 29 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
e409465
GPU kernels for wicks
yfhuang93 Mar 25, 2024
1a9baff
Add GPU support for local_energy_wicks
yfhuang93 Mar 25, 2024
80fe33e
Add MSD-GPU support for green function
yfhuang93 Mar 25, 2024
bb7269c
Add MSD-GPU support for overlap.
yfhuang93 Mar 25, 2024
6cf05a3
Add GPU branches for cal_green_function and cal_overlap
yfhuang93 Mar 25, 2024
8da9440
Delete irrelevant timings.
yfhuang93 Mar 25, 2024
a2c1c58
Cast arrays to cupy for UHFwalkersParticleHole
yfhuang93 Mar 25, 2024
4d43b60
clean up the element-wise kernels
yfhuang93 Mar 27, 2024
11f919a
fix import bug
yfhuang93 Mar 27, 2024
ad05ff5
Stay up-to-date with green_function_noci
yfhuang93 Mar 28, 2024
7e34a33
update gpu supports for msd trial
yfhuang93 Mar 28, 2024
5a3c953
delete irrelevant imports and functions for timing and memory use pur…
yfhuang93 Mar 28, 2024
f4a9a5f
Merge branch 'develop' into msd_gpu
fdmalone Mar 30, 2024
d544e0d
Merge branch 'JoonhoLee-Group:develop' into msd_gpu
yfhuang93 Apr 1, 2024
1cddf6d
remove prints and _gpu suffix for msd greenfunction and overlap
yfhuang93 Apr 1, 2024
ab46e75
Fix reduce_CI_nfold parameters.
yfhuang93 Apr 15, 2024
10ecf73
Merge branch 'JoonhoLee-Group:develop' into msd_gpu
yfhuang93 Apr 28, 2024
f0004b7
fix a version problem
yfhuang93 Apr 28, 2024
b6a8a92
delete irrelevant imports
yfhuang93 May 17, 2024
5cd1d9d
Merge branch 'JoonhoLee-Group:develop' into msd_gpu
yfhuang93 Jun 24, 2024
5857bee
Merge branch 'develop' into msd_gpu
linusjoonho Aug 31, 2024
e24482e
fix requirements.txt for pytest
yfhuang93 Sep 3, 2024
5bdeb9d
delete cupy requirement
yfhuang93 Sep 5, 2024
9fcfb12
Merge branch 'JoonhoLee-Group:develop' into msd_gpu
yfhuang93 Sep 5, 2024
aa0570a
pylint test cleanup
yfhuang93 Sep 5, 2024
ab0dc67
Merge branch 'msd_gpu' of https://github.com/yfhuang93/ipie into msd_gpu
yfhuang93 Sep 5, 2024
f8b3282
skip pylint test for cupy import
yfhuang93 Sep 9, 2024
175927f
fix overlap.py version problem for NOCI
yfhuang93 Sep 9, 2024
fce3fbd
black reformat
yfhuang93 Sep 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 190 additions & 25 deletions ipie/estimators/greens_function_multi_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@
import scipy.linalg
from numba import jit

from ipie.estimators.kernels.cpu import wicks as wk
from ipie.legacy.estimators.greens_function import gab_mod
from ipie.config import config
from ipie.propagation.overlap import (
compute_determinants_batched,
get_cofactor_matrix_batched,
get_det_matrix_batched,
get_overlap_one_det_wicks,
reduce_to_CI_tensor,
get_cofactor_matrix_batched,
)
from ipie.utils.linalg import minor_mask
from ipie.propagation.overlap import get_det_matrix_batched, reduce_to_CI_tensor

if config.get_option("use_gpu"):
from ipie.estimators.kernels.gpu import wicks_gpu as wk
else:
from ipie.estimators.kernels.cpu import wicks as wk

from ipie.legacy.estimators.greens_function import gab_mod

from ipie.utils.backend import arraylib as xp


def greens_function_multi_det(walker_batch, trial, build_full=False):
Expand Down Expand Up @@ -145,7 +152,7 @@ def greens_function_multi_det_wicks(walker_batch, trial, build_full=False):
det : float64 / complex128
Determinant of overlap matrix.
"""
tot_ovlps = numpy.zeros(walker_batch.nwalkers, dtype=numpy.complex128)
tot_ovlps = xp.zeros(walker_batch.nwalkers, dtype=numpy.complex128)
nbasis = walker_batch.Ga.shape[-1]

walker_batch.Ga.fill(0.0 + 0.0j)
Expand All @@ -155,15 +162,15 @@ def greens_function_multi_det_wicks(walker_batch, trial, build_full=False):
phia = walker_batch.phia[iw] # walker wfn
phib = walker_batch.phib[iw] # walker wfn

Oalpha = numpy.dot(trial.psi0a.conj().T, phia)
sign_a, logdet_a = numpy.linalg.slogdet(Oalpha)
Oalpha = xp.dot(trial.psi0a.conj().T, phia)
sign_a, logdet_a = xp.linalg.slogdet(Oalpha)
logdet_b, sign_b = 0.0, 1.0
Obeta = numpy.dot(trial.psi0b.conj().T, phib)
sign_b, logdet_b = numpy.linalg.slogdet(Obeta)
Obeta = xp.dot(trial.psi0b.conj().T, phib)
sign_b, logdet_b = xp.linalg.slogdet(Obeta)

ovlp0 = sign_a * sign_b * numpy.exp(logdet_a + logdet_b)
walker_batch.det_ovlpas[iw, 0] = sign_a * numpy.exp(logdet_a)
walker_batch.det_ovlpbs[iw, 0] = sign_b * numpy.exp(logdet_b)
ovlp0 = sign_a * sign_b * xp.exp(logdet_a + logdet_b)
walker_batch.det_ovlpas[iw, 0] = sign_a * xp.exp(logdet_a)
walker_batch.det_ovlpbs[iw, 0] = sign_b * xp.exp(logdet_b)

# G0, G0H = gab_spin(trial.psi0, phi, nup, ndown)
G0a, G0Ha = gab_mod(trial.psi0a, phia)
Expand All @@ -172,8 +179,8 @@ def greens_function_multi_det_wicks(walker_batch, trial, build_full=False):
walker_batch.G0b[iw] = G0b
walker_batch.Ghalfa[iw] = G0Ha
walker_batch.Ghalfb[iw] = G0Hb
walker_batch.Q0a[iw] = numpy.eye(nbasis) - walker_batch.G0a[iw]
walker_batch.Q0b[iw] = numpy.eye(nbasis) - walker_batch.G0b[iw]
walker_batch.Q0a[iw] = xp.eye(nbasis) - walker_batch.G0a[iw]
walker_batch.Q0b[iw] = xp.eye(nbasis) - walker_batch.G0b[iw]

G0a = walker_batch.G0a[iw]
G0b = walker_batch.G0b[iw]
Expand Down Expand Up @@ -268,13 +275,13 @@ def greens_function_multi_det_wicks(walker_batch, trial, build_full=False):
) # 2 2

elif nex_a > 3:
det_a = numpy.zeros((nex_a, nex_a), dtype=numpy.complex128)
det_a = xp.zeros((nex_a, nex_a), dtype=numpy.complex128)
for iex in range(nex_a):
det_a[iex, iex] = G0a[trial.cre_a[jdet][iex], trial.anh_a[jdet][iex]]
for jex in range(iex + 1, nex_a):
det_a[iex, jex] = G0a[trial.cre_a[jdet][iex], trial.anh_a[jdet][jex]]
det_a[jex, iex] = G0a[trial.cre_a[jdet][jex], trial.anh_a[jdet][iex]]
cofactor = numpy.zeros((nex_a - 1, nex_a - 1), dtype=numpy.complex128)
cofactor = xp.zeros((nex_a - 1, nex_a - 1), dtype=numpy.complex128)
for iex in range(nex_a):
p = trial.cre_a[jdet][iex]
for jex in range(nex_a):
Expand Down Expand Up @@ -334,13 +341,13 @@ def greens_function_multi_det_wicks(walker_batch, trial, build_full=False):
) # 2 2

elif nex_b > 3:
det_b = numpy.zeros((nex_b, nex_b), dtype=numpy.complex128)
det_b = xp.zeros((nex_b, nex_b), dtype=numpy.complex128)
for iex in range(nex_b):
det_b[iex, iex] = G0b[trial.cre_b[jdet][iex], trial.anh_b[jdet][iex]]
for jex in range(iex + 1, nex_b):
det_b[iex, jex] = G0b[trial.cre_b[jdet][iex], trial.anh_b[jdet][jex]]
det_b[jex, iex] = G0b[trial.cre_b[jdet][jex], trial.anh_b[jdet][iex]]
cofactor = numpy.zeros((nex_b - 1, nex_b - 1), dtype=numpy.complex128)
cofactor = xp.zeros((nex_b - 1, nex_b - 1), dtype=numpy.complex128)
for iex in range(nex_b):
p = trial.cre_b[jdet][iex]
for jex in range(nex_b):
Expand Down Expand Up @@ -434,6 +441,7 @@ def build_CI_single_excitation_opt(walker_batch, trial, c_phasea_ovlpb, c_phaseb
-------
None, modifies walker_batch.CIa, and walker_batch.CIb inplace.
"""

if trial.cre_ex_a[1].shape[0] == 0:
pass
else:
Expand Down Expand Up @@ -587,6 +595,7 @@ def build_CI_double_excitation_opt(walker_batch, trial, c_phasea_ovlpb, c_phaseb
-------
None, modifies walker_batch.CIa, and walker_batch.CIb inplace.
"""

if trial.cre_ex_a[2].shape[0] == 0:
pass
else:
Expand Down Expand Up @@ -634,6 +643,7 @@ def build_CI_triple_excitation_opt(walker_batch, trial, c_phasea_ovlpb, c_phaseb
-------
None, modifies walker_batch.CIa, and walker_batch.CIb inplace.
"""

if trial.cre_ex_a[3].shape[0] == 0:
pass
else:
Expand Down Expand Up @@ -1034,13 +1044,15 @@ def build_CI_nfold_excitation_opt(nexcit, walker_batch, trial, c_phasea_ovlpb, c
-------
None, modifies walker_batch.CIa, and walker_batch.CIb inplace.
"""

ndets_a = len(trial.cre_ex_a[nexcit])
nwalkers = walker_batch.G0a.shape[0]
if ndets_a == 0:
pass
else:
det_mat = xp.zeros((nwalkers, ndets_a, nexcit, nexcit), dtype=numpy.complex128)
phases = c_phasea_ovlpb[:, trial.excit_map_a[nexcit]]
det_mat = numpy.zeros((nwalkers, ndets_a, nexcit, nexcit), dtype=numpy.complex128)

wk.build_det_matrix(
trial.cre_ex_a[nexcit],
trial.anh_ex_a[nexcit],
Expand All @@ -1049,7 +1061,10 @@ def build_CI_nfold_excitation_opt(nexcit, walker_batch, trial, c_phasea_ovlpb, c
walker_batch.Ghalfa,
det_mat,
)
cof_mat = numpy.zeros((nwalkers, ndets_a, nexcit - 1, nexcit - 1), dtype=numpy.complex128)

cof_mat = xp.zeros((nwalkers, ndets_a, nexcit - 1, nexcit - 1), dtype=numpy.complex128)
phases = c_phasea_ovlpb[:, trial.excit_map_a[nexcit]]

wk.reduce_CI_nfold(
trial.cre_ex_a[nexcit],
trial.anh_ex_a[nexcit],
Expand All @@ -1060,12 +1075,14 @@ def build_CI_nfold_excitation_opt(nexcit, walker_batch, trial, c_phasea_ovlpb, c
cof_mat,
walker_batch.CIa,
)

ndets_b = len(trial.cre_ex_b[nexcit])

if ndets_b == 0:
pass
else:
phases = c_phaseb_ovlpa[:, trial.excit_map_b[nexcit]]
det_mat = numpy.zeros((nwalkers, ndets_b, nexcit, nexcit), dtype=numpy.complex128)

det_mat = xp.zeros((nwalkers, ndets_b, nexcit, nexcit), dtype=numpy.complex128)
wk.build_det_matrix(
trial.cre_ex_b[nexcit],
trial.anh_ex_b[nexcit],
Expand All @@ -1074,7 +1091,9 @@ def build_CI_nfold_excitation_opt(nexcit, walker_batch, trial, c_phasea_ovlpb, c
walker_batch.Ghalfb,
det_mat,
)
cof_mat = numpy.zeros((nwalkers, ndets_b, nexcit - 1, nexcit - 1), dtype=numpy.complex128)
cof_mat = xp.zeros((nwalkers, ndets_b, nexcit - 1, nexcit - 1), dtype=numpy.complex128)
phases = c_phaseb_ovlpa[:, trial.excit_map_b[nexcit]]

wk.reduce_CI_nfold(
trial.cre_ex_b[nexcit],
trial.anh_ex_b[nexcit],
Expand Down Expand Up @@ -1110,6 +1129,26 @@ def contract_CI(Q0_act, CI, Ghalf, G):
G[iw] += numpy.dot(Q0_act[iw], numpy.dot(CI[iw], Ghalf[iw]))


def contract_CI_gpu(Q0_act, CI, Ghalf, G):
"""numba kernel to contract Q, CI and Ghalf to form G

Parameters
----------
Q0_act : numpy.ndarray
1-G.
CI : numpy.ndarray
Intermediate tensor.
Ghalf : numpy.ndarray
Walker half rotated Green's function
G: numpy.ndarray
Walker Green's function
Returns
-------
None, modifies G in place
"""
G += xp.einsum("wOa, wae, weo-> wOo", Q0_act, CI, Ghalf)
yfhuang93 marked this conversation as resolved.
Show resolved Hide resolved


def greens_function_multi_det_wicks_opt(walker_batch, trial, build_full=False):
"""Compute walker's green's function using Wick's theorem.

Expand All @@ -1124,7 +1163,6 @@ def greens_function_multi_det_wicks_opt(walker_batch, trial, build_full=False):
det : float64 / complex128
Determinant of overlap matrix.
"""

nbasis = walker_batch.Ga.shape[-1]

walker_batch.Ga.fill(0.0 + 0.0j)
Expand Down Expand Up @@ -1213,4 +1251,131 @@ def greens_function_multi_det_wicks_opt(walker_batch, trial, build_full=False):
walker_batch.Gb *= (ovlps0 / ovlps)[:, None, None]
walker_batch.det_ovlpas[:, 0] = signs_a * numpy.exp(logdets_a)
walker_batch.det_ovlpbs[:, 0] = signs_b * numpy.exp(logdets_b)

return ovlps


def greens_function_multi_det_wicks_opt_gpu(walker_batch, trial, build_full=False):
"""Compute walker's green's function using Wick's theorem.

Parameters
----------
walker_batch : object
MultiDetTrialWalkerBatch object.
trial : object
Trial wavefunction object.
Returns
-------
det : float64 / complex128
Determinant of overlap matrix.
"""

nbasis = walker_batch.Ga.shape[-1]

walker_batch.Ga.fill(0.0 + 0.0j)
walker_batch.Gb.fill(0.0 + 0.0j)

# Build reference Green's functions and overlaps
# Note abuse of naming convention this is really theta for the reference
# determinant.
G0a = xp.zeros((walker_batch.nwalkers, nbasis, nbasis), dtype=numpy.complex128)
G0b = xp.zeros((walker_batch.nwalkers, nbasis, nbasis), dtype=numpy.complex128)
ovlps0 = xp.zeros((walker_batch.nwalkers), dtype=numpy.complex128)
signs_a = xp.zeros_like(ovlps0)
signs_b = xp.zeros_like(ovlps0)
logdets_a = xp.zeros_like(ovlps0)
logdets_b = xp.zeros_like(ovlps0)

trial_psi0a_conj = xp.zeros_like(trial.psi0a.conj())
trial_psi0a_conj.set(trial.psi0a.conj())
trial_psi0b_conj = xp.zeros_like(trial.psi0b.conj())
trial_psi0b_conj.set(trial.psi0b.conj())

ovlp = xp.einsum("wex,xE->weE", walker_batch.phia.transpose(0, 2, 1), trial_psi0a_conj)
ovlp_inv = xp.linalg.inv(ovlp)
walker_batch.Ghalfa = xp.einsum("weE,wEo->weo", ovlp_inv, walker_batch.phia.transpose(0, 2, 1))
G0a = xp.einsum("ox,wxO->woO", trial_psi0a_conj, walker_batch.Ghalfa)
signs_a, logdets_a = xp.linalg.slogdet(ovlp)
signs_b = [1.0 for i in range(walker_batch.nwalkers)]
logdets_b = [0.0 for i in range(walker_batch.nwalkers)]
ovlp = xp.einsum("wex,xE->weE", walker_batch.phib.transpose(0, 2, 1), trial_psi0b_conj)
signs_b, logdets_b = xp.linalg.slogdet(ovlp)
ovlp_inv = xp.linalg.inv(ovlp)
walker_batch.Ghalfb = xp.einsum("weE,wEo->weo", ovlp_inv, walker_batch.phib.transpose(0, 2, 1))
G0b = xp.einsum("ox,wxO->woO", trial_psi0b_conj, walker_batch.Ghalfb)

trial_psi0a_conj = None
trial_psi0b_conj = None
ovlp = None
ovlp_inv = None

walker_batch.Ghalfa = xp.ascontiguousarray(walker_batch.Ghalfa)
walker_batch.Ghalfb = xp.ascontiguousarray(walker_batch.Ghalfb)

ovlps0 = signs_a * signs_b * xp.exp(logdets_a + logdets_b)
walker_batch.G0a = G0a
walker_batch.G0b = G0b
walker_batch.Q0a = xp.eye(nbasis)[None, :] - G0a
walker_batch.Q0b = xp.eye(nbasis)[None, :] - G0b
walker_batch.CIa.fill(0.0 + 0.0j)
walker_batch.CIb.fill(0.0 + 0.0j)

dets_a_full, dets_b_full = compute_determinants_batched(
walker_batch.Ghalfa, walker_batch.Ghalfb, trial
)

walker_batch.det_ovlpas = dets_a_full * xp.asarray(trial.phase_a[None, :]) # phase included
walker_batch.det_ovlpbs = dets_b_full * xp.asarray(trial.phase_b[None, :]) # phase included
ovlpa = walker_batch.det_ovlpas
ovlpb = walker_batch.det_ovlpbs

c_phasea_ovlpb = xp.einsum(
"wJ,J->wJ", ovlpb, xp.asarray(trial.phase_a * trial.coeffs.conj()), optimize=True
)
c_phaseb_ovlpa = xp.einsum(
"wJ,J->wJ", ovlpa, xp.asarray(trial.phase_b * trial.coeffs.conj()), optimize=True
)
# contribution 1 (disconnected diagrams)
ovlps = xp.einsum("wJ,J->w", ovlpa * ovlpb, xp.asarray(trial.coeffs.conj()), optimize=True)
walker_batch.Ga += xp.einsum("w,wpq->wpq", ovlps, G0a, optimize=True)
walker_batch.Gb += xp.einsum("w,wpq->wpq", ovlps, G0b, optimize=True)
# intermediates for contribution 2 (connected diagrams)

if trial.max_excite >= 1:
build_CI_single_excitation_opt(walker_batch, trial, c_phasea_ovlpb, c_phaseb_ovlpa)

if trial.max_excite >= 2:
build_CI_double_excitation_opt(walker_batch, trial, c_phasea_ovlpb, c_phaseb_ovlpa)

if trial.max_excite >= 3:
build_CI_triple_excitation_opt(walker_batch, trial, c_phasea_ovlpb, c_phaseb_ovlpa)

for iexcit in range(4, trial.max_excite + 1):
build_CI_nfold_excitation_opt(iexcit, walker_batch, trial, c_phasea_ovlpb, c_phaseb_ovlpa)

# contribution 2 (connected diagrams)
# Frozen orbitals not in original active space calculation but reincluded in
# AFQMC

act_orb = trial.act_orb_alpha
contract_CI_gpu(
walker_batch.Q0a[:, :, act_orb].copy(),
walker_batch.CIa,
walker_batch.Ghalfa[:, act_orb].copy(),
walker_batch.Ga,
)
act_orb = trial.act_orb_beta
contract_CI_gpu(
walker_batch.Q0b[:, :, act_orb].copy(),
walker_batch.CIb,
walker_batch.Ghalfb[:, act_orb].copy(),
walker_batch.Gb,
)
# multiplying everything by reference overlap
ovlps *= ovlps0
walker_batch.Ga *= (ovlps0 / ovlps)[:, None, None]
walker_batch.Gb *= (ovlps0 / ovlps)[:, None, None]
walker_batch.det_ovlpas[:, 0] = signs_a * xp.exp(logdets_a)
walker_batch.det_ovlpbs[:, 0] = signs_b * xp.exp(logdets_b)

return ovlps
Loading
Loading