Skip to content

Commit

Permalink
punt on type preservation
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Jul 26, 2024
1 parent 91d1135 commit 6cf5cb6
Showing 1 changed file with 11 additions and 17 deletions.
28 changes: 11 additions & 17 deletions merlin/dag/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,32 +389,26 @@ def transform(
if col_dtype:
output_dtypes[col_name] = md.dtype(col_dtype).to_numpy

def empty_like(df, cols):
# Construct an empty DataFrame with the same dtypes as df

# TODO: constructing meta like this can loose dtype information for
# columns that are arbitrarily set to 'float64'. We should propagate
# dtype information along with column names in the columngroup graph.
# This currently only happens during intermediate 'fit' transforms,
# so as long as statoperators don't require dtype information on the
# DDF this doesn't matter all that much
def make_empty(df, cols):
# Construct an empty DataFrame

# TODO: constructing meta like this loses dtype information on the ddf
# and sets it all to 'float64'. We should propagate dtype information along
# with column names in the columngroup graph. This currently only
# happens during intermediate 'fit' transforms, so as long as statoperators
# don't require dtype information on the DDF this doesn't matter all that much
return df._constructor(
{
col: df._constructor_sliced(
[], dtype=df[col].dtype if col in df.columns else "float64"
)
for col in cols
}
{col: df._constructor_sliced([], dtype="float64") for col in cols}
)

if isinstance(output_dtypes, dict) and isinstance(ddf._meta, pd.DataFrame):
dtypes = output_dtypes
output_dtypes = empty_like(ddf._meta, columns)
output_dtypes = make_empty(ddf._meta, columns)
for col_name, col_dtype in dtypes.items():
output_dtypes[col_name] = output_dtypes[col_name].astype(col_dtype)

elif not output_dtypes:
output_dtypes = empty_like(ddf._meta, columns)
output_dtypes = make_empty(ddf._meta, columns)

return ensure_optimize_dataframe_graph(
ddf=ddf.map_partitions(
Expand Down

0 comments on commit 6cf5cb6

Please sign in to comment.