Skip to content

Commit

Permalink
Merge pull request #66 from robinzyb/devel
Browse files Browse the repository at this point in the history
remove the old cp2kcube class and add type hints
  • Loading branch information
robinzyb committed Jul 12, 2024
2 parents 1883beb + 4c42f00 commit a88786e
Showing 1 changed file with 52 additions and 174 deletions.
226 changes: 52 additions & 174 deletions cp2kdata/cube/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,144 +2,19 @@
from copy import deepcopy

import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from scipy import fft
from ase import Atom, Atoms
from monty.json import MSONable
import asciichartpy as acp

from cp2kdata.log import get_logger
from cp2kdata.utils import file_content, interpolate_spline
from cp2kdata.utils import au2A, au2eV
from cp2kdata.cell import Cp2kCell

# parse cp2kcube


class Cp2kCubeOld:
"""
timestep: unit ps
"""

def __init__(self, cube_file_name, timestep=0):
print("Warning: This is Cp2kCubeOld is deprecated after version 0.6.x, use Cp2kCube instead!")
print("Warning: to use old one, from cp2kdata.cube.cube import Cp2kCubeOld")
self.file = cube_file_name
self.timestep = timestep
self.cube_vals = self.read_cube_vals()
self.cell_x = self.grid_size[0]*self.grid_space[0]
self.cell_y = self.grid_size[1]*self.grid_space[1]
self.cell_z = self.grid_size[2]*self.grid_space[2]

@property
def num_atoms(self):
line = file_content(self.file, 2)
num_atoms = int(line.split()[0])
return num_atoms

@property
def grid_size(self):
# read grid point and grid size, unit: angstrom
content_list = file_content(self.file, (3, 6))
content_list = content_list.split()
num_x = int(content_list[0])
num_y = int(content_list[4])
num_z = int(content_list[8])
return num_x, num_y, num_z

@property
def grid_space(self):
# read grid point and grid size, unit: angstrom
content_list = file_content(self.file, (3, 6))
content_list = content_list.split()
step_x = float(content_list[1])*au2A
step_y = float(content_list[6])*au2A
step_z = float(content_list[11])*au2A
return step_x, step_y, step_z

def get_stc(self):
atom_list = []
for i in range(self.num_atoms):
stc_vals = file_content(self.file, (6+i, 6+i+1))
stc_vals = stc_vals.split()
atom = Atom(
symbol=int(stc_vals[0]),
position=(
float(stc_vals[2])*au2A, float(stc_vals[3])*au2A, float(stc_vals[4])*au2A)
)
atom_list.append(atom)

stc = Atoms(atom_list)
stc.set_cell([self.cell_x, self.cell_y, self.cell_z])
return stc

def read_cube_vals(self):
# read the cube value from file
cube_vals = file_content(self.file, (6+self.num_atoms,))
cube_vals = cube_vals.split()
cube_vals = np.array(cube_vals, dtype=float)
cube_vals = cube_vals.reshape(self.grid_size)
cube_vals = cube_vals*au2eV
return cube_vals

def get_pav(self, axis="z", interpolate=False):
# do the planar average along specific axis
if axis == 'x':
vals = self.cube_vals.mean(axis=(1, 2))
points = np.arange(0, self.grid_size[0])*self.grid_space[0]
length = self.grid_size[0]*self.grid_space[0]
elif axis == 'y':
vals = self.cube_vals.mean(axis=(0, 2))
points = np.arange(0, self.grid_size[1])*self.grid_space[1]
length = self.grid_size[1]*self.grid_space[1]
elif axis == 'z':
vals = self.cube_vals.mean(axis=(0, 1))
points = np.arange(0, self.grid_size[2])*self.grid_space[2]
length = self.grid_size[2]*self.grid_space[2]
else:
print("not such plane average style!")

# interpolate or note
if interpolate:
# set the last point same as first point
points = np.append(points, length)
vals = np.append(vals, vals[0])
new_points = np.linspace(0, length, 4097)[:-1]
new_points, new_vals = interpolate_spline(points, vals, new_points)
return new_points, new_vals
else:
return points, vals

def get_mav(self, l1, l2=0, ncov=1, interpolate=False):
axis = "z"
pav_x, pav = self.get_pav(axis=axis, interpolate=interpolate)
theta_1_fft = fft.fft(self.square_wave_filter(pav_x, l1, self.cell_z))
pav_fft = fft.fft(pav)
mav_fft = pav_fft*theta_1_fft*self.cell_z/len(pav_x)
if ncov == 2:
theta_2_fft = fft.fft(
self.square_wave_filter(pav_x, l2, self.cell_z))
mav_fft = mav_fft*theta_2_fft*self.cell_z/len(pav_x)
mav = fft.ifft(mav_fft)
return pav_x, np.real(mav)

def quick_plot(self, axis="z", interpolate=False, output_dir="./"):
x, y = self.get_pav(axis=axis, interpolate=interpolate)
plt.figure(figsize=(9, 9), dpi=100)
plt.plot(x, y, label=("PAV"+axis))
plt.xlabel(axis + " [A]")
plt.ylabel("Hartree [eV]")
plt.legend()
plt.savefig(os.path.join(output_dir, "pav.png"), dpi=100)

@staticmethod
def square_wave_filter(x, l, cell_z):
half_l = l/2
x_1st, x_2nd = np.array_split(x, 2)
y_1st = np.heaviside(half_l - np.abs(x_1st), 0)/l
y_2nd = np.heaviside(half_l - np.abs(x_2nd-cell_z), 0)/l
y = np.concatenate([y_1st, y_2nd])
return y

logger = get_logger(__name__)

class Cp2kCube(MSONable):
# add MSONable use as_dict and from_dict
Expand All @@ -148,50 +23,39 @@ class Cp2kCube(MSONable):
"""

def __init__(self, fname: str = None, cube_vals: np.ndarray = None, cell: Cp2kCell = None, stc: Atoms = None):
print("Warning: This is New Cp2kCube Class, if you want to use old Cp2kCube")
print("try, from cp2kdata.cube.cube import Cp2kCubeOld")
print("New Cp2kCube return raw values in cp2k cube file")
print("that is, length in bohr and energy in hartree for potential file")
print("that is, length in bohr and density in e/bohr^3 for density file")
print("to convert unit: try from cp2kdata.utils import au2A, au2eV")

"""
New Cp2kCube return raw values in cp2k cube file
Units in Cp2kCube class
length in bohr and energy in hartree for potential file
length in bohr and density in e/bohr^3 for density file
to convert unit: try from cp2kdata.utils import au2A, au2eV
"""
self.file = fname

if cell is None:
self.cell = self.read_cell()
_grid_point = self._parse_grid_point(self.file)
_gs_matrix = self._parse_gs_matrix(self.file)
self.cell = self._parse_cell(_grid_point, _gs_matrix)
else:
self.cell = cell
if stc is None:
self.stc = self._parse_stc()
_num_atoms = self._parse_num_atoms(self.file)
self.stc = self._parse_stc(_num_atoms, self.file, self.cell)
else:
self.stc = stc

if cube_vals is None:
self.cube_vals = self.read_cube_vals(self.file,
self.cube_vals = self._parse_cube_vals(self.file,
self.num_atoms,
self.cell.grid_point
)
else:
self.cube_vals = cube_vals

def read_cell(self):
grid_point = self.read_grid_point(self.file)
gs_matrix = self.read_gs_matrix(self.file)
cell_param = gs_matrix*grid_point[:, np.newaxis]
return Cp2kCell(cell_param, grid_point, gs_matrix)

@property
def num_atoms(self):
return len(self.stc)

def _parse_num_atoms(self):
"""
be used to parse the number of atoms from the cube file only
"""
line = file_content(self.file, 2)
num_atoms = int(line.split()[0])
return num_atoms

def as_dict(self):
"""Returns data dict of Cp2kCube instance."""
data_dict = {
Expand Down Expand Up @@ -224,23 +88,6 @@ def __sub__(self, others):
raise RuntimeError("Unspported Class")
return other_copy

def _parse_stc(self):
num_atoms = self._parse_num_atoms()
atom_list = []
for i in range(num_atoms):
stc_vals = file_content(self.file, (6+i, 6+i+1))
stc_vals = stc_vals.split()
atom = Atom(
symbol=int(stc_vals[0]),
position=(
float(stc_vals[2])*au2A, float(stc_vals[3])*au2A, float(stc_vals[4])*au2A)
)
atom_list.append(atom)

stc = Atoms(atom_list)
stc.set_cell(self.cell.cell_matrix*au2A)
return stc

def get_stc(self):
return self.stc

Expand Down Expand Up @@ -380,7 +227,7 @@ def write_cube(self, fname, comments='#'):
fw.write(f'{self.cube_vals[i,j,k]:13.5E}')
if (k+1) % 6 == 0:
fw.write('\n')
# write a blank line after each z value
# write a new line character after each z value
if grid_point[2] % 6 != 0:
fw.write('\n')

Expand Down Expand Up @@ -421,7 +268,7 @@ def reduce_resolution(self, stride, axis='xyz'):
return new_cube

@staticmethod
def read_gs_matrix(fname):
def _parse_gs_matrix(fname: str) -> npt.NDArray[np.float64]:
content_list = file_content(fname, (3, 6))
content_list = content_list.split()

Expand All @@ -437,7 +284,7 @@ def read_gs_matrix(fname):
return gs_matrix

@staticmethod
def read_grid_point(fname):
def _parse_grid_point(fname: str) -> npt.NDArray[np.int64]:
# read grid point and grid size, unit: angstrom
content_list = file_content(fname, (3, 6))
content_list = content_list.split()
Expand All @@ -447,7 +294,38 @@ def read_grid_point(fname):
return np.array([num_x, num_y, num_z])

@staticmethod
def read_cube_vals(fname, num_atoms, grid_point):
def _parse_num_atoms(fname: str) -> int:
"""
be used to parse the number of atoms from the cube file only
"""
line = file_content(fname, 2)
num_atoms = int(line.split()[0])
return num_atoms

@staticmethod
def _parse_cell(grid_point: npt.NDArray[np.int64], gs_matrix: npt.NDArray[np.float64]) -> Cp2kCell:
cell_param = gs_matrix * grid_point[:, np.newaxis]
return Cp2kCell(cell_param, grid_point, gs_matrix)

@staticmethod
def _parse_stc(num_atoms: int, fname: str, cell: Cp2kCell) -> Atoms:
atom_list = []
for i in range(num_atoms):
stc_vals = file_content(fname, (6+i, 6+i+1))
stc_vals = stc_vals.split()
atom = Atom(
symbol=int(stc_vals[0]),
position=(
float(stc_vals[2])*au2A, float(stc_vals[3])*au2A, float(stc_vals[4])*au2A)
)
atom_list.append(atom)

stc = Atoms(atom_list)
stc.set_cell(cell.cell_matrix*au2A)
return stc

@staticmethod
def _parse_cube_vals(fname: str, num_atoms: int, grid_point: npt.NDArray[np.int64]) -> npt.NDArray[np.float64]:
# read the cube value from file
cube_vals = file_content(fname, (6+num_atoms,))
cube_vals = cube_vals.split()
Expand All @@ -456,7 +334,7 @@ def read_cube_vals(fname, num_atoms, grid_point):
return cube_vals

@staticmethod
def square_wave_filter(x, l, cell_z):
def square_wave_filter(x: npt.NDArray[np.float64], l: float, cell_z: float) -> npt.NDArray[np.float64]:
half_l = l/2
x_1st, x_2nd = np.array_split(x, 2)
y_1st = np.heaviside(half_l - np.abs(x_1st), 0)/l
Expand Down

0 comments on commit a88786e

Please sign in to comment.