Use property classes in single table QualityReport (#377)
* remove raw_result

* tests refactoring 1

* data metadata validation

* tests 2

* data/metadata validation

* comments

* typo
R-Palazzo committed Jun 30, 2023
1 parent fb351c6 commit 1f4df05
Showing 6 changed files with 580 additions and 936 deletions.
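For orientation, here is a minimal usage sketch of the refactored report. The toy tables and the metadata dict with per-column `sdtype` entries are assumptions for illustration, not part of this commit; only methods that appear in the diff below are exercised.

```python
# Illustrative sketch only: toy data, real usage passes real vs. synthetic tables.
import pandas as pd

from sdmetrics.reports.single_table import QualityReport

real_data = pd.DataFrame({
    'age': [25, 31, 47, 52],
    'job': ['nurse', 'teacher', 'teacher', 'nurse'],
})
synthetic_data = pd.DataFrame({
    'age': [27, 33, 45, 50],
    'job': ['nurse', 'nurse', 'teacher', 'teacher'],
})
metadata = {
    'columns': {
        'age': {'sdtype': 'numerical'},      # assumed metadata format
        'job': {'sdtype': 'categorical'},
    },
}

report = QualityReport()
report.generate(real_data, synthetic_data, metadata, verbose=True)

print(report.get_score())                    # overall score (NaN-safe mean of the properties)
print(report.get_properties())               # 'Property' / 'Score' DataFrame
print(report.get_details('Column Shapes'))   # per-column details for one property
fig = report.get_visualization('Column Shapes')  # plotly Figure
```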
2 changes: 1 addition & 1 deletion sdmetrics/reports/single_table/_properties/base.py
@@ -16,7 +16,7 @@ def _compute_average(self):
if not isinstance(self._details, pd.DataFrame) or 'Score' not in self._details.columns:
raise ValueError("The property details must be a DataFrame with a 'Score' column.")

return round(self._details['Score'].mean(), 3)
return round(self._details['Score'].mean(), 2)

def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=None):
"""Generate the _details dataframe for the property."""
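The only functional change in `base.py` is rounding the property average to two decimal places instead of three. A standalone sketch of the computation, using a hypothetical `_details` frame:

```python
import pandas as pd

# Hypothetical per-column details, mirroring the 'Score' column that
# _compute_average expects on the property's _details DataFrame.
details = pd.DataFrame({
    'Column': ['age', 'job', 'salary'],
    'Score': [0.812, 0.745, 0.701],
})

print(round(details['Score'].mean(), 3))  # 0.753 (previous behaviour)
print(round(details['Score'].mean(), 2))  # 0.75  (behaviour after this commit)
```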
255 changes: 97 additions & 158 deletions sdmetrics/reports/single_table/quality_report.py
@@ -1,21 +1,14 @@
"""Single table quality report."""

import itertools
import pickle
import sys
import warnings

import numpy as np
import pandas as pd
import pkg_resources
import tqdm

from sdmetrics.errors import IncomputableMetricError
from sdmetrics.reports.single_table.plot_utils import get_column_pairs_plot, get_column_shapes_plot
from sdmetrics.reports.utils import (
aggregate_metric_results, discretize_and_apply_metric, validate_single_table_inputs)
from sdmetrics.single_table import (
ContingencySimilarity, CorrelationSimilarity, KSComplement, TVComplement)
from sdmetrics.reports.single_table._properties import ColumnPairTrends, ColumnShapes
from sdmetrics.reports.utils import _validate_categorical_values


class QualityReport():
@@ -25,35 +18,72 @@ class QualityReport():
score along two properties - Column Shapes and Column Pair Trends.
"""

METRICS = {
'Column Shapes': [KSComplement, TVComplement],
'Column Pair Trends': [CorrelationSimilarity, ContingencySimilarity],
}

def __init__(self):
self._overall_quality_score = None
self._metric_results = {}
self._property_breakdown = {}
self._property_errors = {}
self.is_generated = False
self._properties = {
'Column Shapes': ColumnShapes(),
'Column Pair Trends': ColumnPairTrends()
}

def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
"""Validate that the metadata matches the data.
Raise an error if the column metadata does not match the column data.
Args:
real_data (pandas.DataFrame):
The real data.
synthetic_data (pandas.DataFrame):
The synthetic data.
metadata (dict):
The metadata of the table.
"""
real_columns = set(real_data.columns)
synthetic_columns = set(synthetic_data.columns)
metadata_columns = set(metadata['columns'].keys())

missing_data = metadata_columns.difference(real_columns.union(synthetic_columns))
missing_metadata = real_columns.union(synthetic_columns).difference(metadata_columns)
missing_columns = missing_data.union(missing_metadata)

if missing_columns:
error_message = (
'The metadata does not match the data. The following columns are missing'
' in the real/synthetic data or in the metadata: '
f"{', '.join(sorted(missing_columns))}"
)
raise ValueError(error_message)

def validate(self, real_data, synthetic_data, metadata):
"""Validate the inputs.
Args:
real_data (pandas.DataFrame):
The real data.
synthetic_data (pandas.DataFrame):
The synthetic data.
metadata (dict):
The metadata of the table.
"""
if not isinstance(metadata, dict):
metadata = metadata.to_dict()

self._validate_metadata_matches_data(real_data, synthetic_data, metadata)
_validate_categorical_values(real_data, synthetic_data, metadata)
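The new `validate` step first checks that the data columns and the metadata columns agree before delegating to `_validate_categorical_values`. A standalone sketch of the same set arithmetic, with hypothetical column names that are not from this commit:

```python
import pandas as pd

real_data = pd.DataFrame({'age': [30], 'job': ['nurse']})
synthetic_data = pd.DataFrame({'age': [31], 'job': ['teacher']})
metadata = {'columns': {'age': {}, 'job': {}, 'salary': {}}}  # 'salary' never appears in the data

real_columns = set(real_data.columns)
synthetic_columns = set(synthetic_data.columns)
metadata_columns = set(metadata['columns'].keys())

# Columns described in the metadata but absent from both tables ...
missing_data = metadata_columns.difference(real_columns.union(synthetic_columns))
# ... and columns present in the data but not described in the metadata.
missing_metadata = real_columns.union(synthetic_columns).difference(metadata_columns)

print(sorted(missing_data.union(missing_metadata)))  # ['salary'] -> validate() would raise ValueError
```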

def _print_results(self, out=sys.stdout):
"""Print the quality report results."""
if pd.isna(self._overall_quality_score) & any(self._property_errors.values()):
out.write('\nOverall Quality Score: Error computing report.\n\n')
else:
out.write(
f'\nOverall Quality Score: {round(self._overall_quality_score * 100, 2)}%\n\n')
out.write(
f'\nOverall Quality Score: {round(self._overall_quality_score * 100, 2)}%\n\n'
)
out.write('Properties:\n')

if len(self._property_breakdown) > 0:
out.write('Properties:\n')

for prop, score in self._property_breakdown.items():
if not pd.isna(score):
out.write(f'{prop}: {round(score * 100, 2)}%\n')
elif self._property_errors[prop] > 0:
out.write(f'{prop}: Error computing property.\n')
else:
out.write(f'{prop}: NaN\n')
for property_name in self._properties:
property_score = self._properties[property_name]._compute_average()
out.write(
f'- {property_name}: {property_score * 100}%\n'
)

def generate(self, real_data, synthetic_data, metadata, verbose=True):
"""Generate report.
Expand All @@ -68,42 +98,33 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True):
verbose (bool):
Whether or not to print report summary and progress.
"""
validate_single_table_inputs(real_data, synthetic_data, metadata)

metrics = list(itertools.chain.from_iterable(self.METRICS.values()))

for metric in tqdm.tqdm(metrics, desc='Creating report', disable=(not verbose)):
try:
self._metric_results[metric.__name__] = metric.compute_breakdown(
real_data, synthetic_data, metadata)
except IncomputableMetricError:
# Metric is not compatible with this dataset.
self._metric_results[metric.__name__] = {}

existing_column_pairs = list(self._metric_results['ContingencySimilarity'].keys())
existing_column_pairs.extend(
list(self._metric_results['CorrelationSimilarity'].keys()))
additional_results = discretize_and_apply_metric(
real_data, synthetic_data, metadata, ContingencySimilarity, existing_column_pairs)
self._metric_results['ContingencySimilarity'].update(additional_results)
self.validate(real_data, synthetic_data, metadata)

self._property_breakdown = {}
for prop, metrics in self.METRICS.items():

num_prop_errors = 0
for metric in metrics:
_, num_metric_errors = aggregate_metric_results(
self._metric_results[metric.__name__])
num_prop_errors += num_metric_errors

self._property_breakdown[prop] = self.get_details(prop)['Quality Score'].mean()
self._property_errors[prop] = num_prop_errors
scores = []
for property_name in self._properties:
scores.append(self._properties[property_name].get_score(
real_data, synthetic_data, metadata)
)

self._overall_quality_score = np.nanmean(list(self._property_breakdown.values()))
self._overall_quality_score = np.nanmean(scores)
self.is_generated = True

if verbose:
self._print_results()
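`generate` now lets each property object compute its own score and aggregates with `np.nanmean`, so a property whose score comes back as NaN does not pull the overall score down to NaN. A small sketch with made-up scores:

```python
import numpy as np

# Hypothetical property scores: Column Shapes computed fine,
# Column Pair Trends failed entirely and reported NaN.
scores = [0.82, np.nan]

print(np.nanmean(scores))            # 0.82, the NaN is ignored
print(np.nanmean([np.nan, np.nan]))  # nan (with a RuntimeWarning) when every score is NaN
```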

def _validate_property_generated(self, property_name):
"""Validate that the given property name and that the report has been generated."""
if property_name not in ['Column Shapes', 'Column Pair Trends']:
raise ValueError(
f"Invalid property name '{property_name}'."
" Valid property names are 'Column Shapes' and 'Column Pair Trends'."
)

if not self.is_generated:
raise ValueError(
'Quality report must be generated before getting details. Call `generate` first.'
)
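Both `get_details` and `get_visualization` route through this guard. A sketch of the two failure modes on a report that was never generated, assuming the usual `sdmetrics` import path:

```python
from sdmetrics.reports.single_table import QualityReport

report = QualityReport()

try:
    report.get_details('Column Shapes')
except ValueError as error:
    # Quality report must be generated before getting details. Call `generate` first.
    print(error)

try:
    report.get_details('Not A Property')
except ValueError as error:
    # Invalid property name 'Not A Property'. Valid property names are ...
    print(error)
```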

def get_score(self):
"""Return the overall quality score.
@@ -114,15 +135,19 @@ def get_score(self):
return self._overall_quality_score

def get_properties(self):
"""Return the property score breakdown.
"""Return the property score.
Returns:
pandas.DataFrame
The property score breakdown.
The property score.
"""
name, score = [], []
for property_name in self._properties:
name.append(property_name)
score.append(self._properties[property_name]._compute_average())
return pd.DataFrame({
'Property': self._property_breakdown.keys(),
'Score': self._property_breakdown.values(),
'Property': name,
'Score': score,
})

def get_visualization(self, property_name):
@@ -136,109 +161,23 @@ def get_visualization(self, property_name):
plotly.graph_objects._figure.Figure
The visualization for the requested property.
"""
score_breakdowns = {
metric.__name__: self._metric_results[metric.__name__]
for metric in self.METRICS.get(property_name, [])
}

if property_name == 'Column Shapes':
fig = get_column_shapes_plot(score_breakdowns, self._property_breakdown[property_name])

elif property_name == 'Column Pair Trends':
fig = get_column_pairs_plot(
score_breakdowns,
self._property_breakdown[property_name],
)
self._validate_property_generated(property_name)

return fig
return self._properties[property_name].get_visualization()

def get_details(self, property_name):
"""Return the details for each score for the given property name.
"""Return the details table for the given property name.
Args:
property_name (str):
The name of the property to return score details for.
The name of the property to return details for.
Returns:
pandas.DataFrame
The score breakdown.
"""
columns = []
metrics = []
scores = []
errors = []
details = pd.DataFrame()

if property_name == 'Column Shapes':
for metric in self.METRICS[property_name]:
for column, score_breakdown in self._metric_results[metric.__name__].items():
if 'score' in score_breakdown and pd.isna(score_breakdown['score']):
continue

columns.append(column)
metrics.append(metric.__name__)
scores.append(score_breakdown.get('score', np.nan))
errors.append(score_breakdown.get('error', np.nan))

details = pd.DataFrame({
'Column': columns,
'Metric': metrics,
'Quality Score': scores,
})

elif property_name == 'Column Pair Trends':
real_scores = []
synthetic_scores = []
for metric in self.METRICS[property_name]:
for column_pair, score_breakdown in self._metric_results[metric.__name__].items():
columns.append(column_pair)
metrics.append(metric.__name__)
scores.append(score_breakdown.get('score', np.nan))
real_scores.append(score_breakdown.get('real', np.nan))
synthetic_scores.append(score_breakdown.get('synthetic', np.nan))
errors.append(score_breakdown.get('error', np.nan))

details = pd.DataFrame({
'Column 1': [col1 for col1, _ in columns],
'Column 2': [col2 for _, col2 in columns],
'Metric': metrics,
'Quality Score': scores,
'Real Correlation': real_scores,
'Synthetic Correlation': synthetic_scores,
})

if pd.Series(errors).notna().sum() > 0:
details['Error'] = errors

return details

def get_raw_result(self, metric_name):
"""Return the raw result of the given metric name.
Args:
metric_name (str):
The name of the desired metric.
self._validate_property_generated(property_name)

Returns:
dict
The raw results
"""
metrics = list(itertools.chain.from_iterable(self.METRICS.values()))
for metric in metrics:
if metric.__name__ == metric_name:
return [
{
'metric': {
'method': f'{metric.__module__}.{metric.__name__}',
'parameters': {},
},
'results': {
key: result for key, result in
self._metric_results[metric_name].items()
if not pd.isna(result.get('score', np.nan))
},
},
]
return self._properties[property_name]._details.copy()
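`get_details` returns a copy of the property's internal `_details` frame, so callers can modify the result without corrupting the report's state. A sketch of that distinction, using a plain DataFrame as a stand-in for the property details:

```python
import pandas as pd

# Stand-in for a property's internal _details frame.
_details = pd.DataFrame({'Column': ['age'], 'Score': [0.9]})

returned = _details.copy()      # what get_details does
returned.loc[0, 'Score'] = 0.0  # caller mutates the returned frame

print(_details.loc[0, 'Score'])  # 0.9, the report's internal details are untouched
```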

def save(self, filepath):
"""Save this report instance to the given path using pickle.
@@ -54,7 +54,7 @@ def test_get_score(self):
}
expected_details = pd.DataFrame(expected_details_dict)
pd.testing.assert_frame_equal(column_shape_property._details, expected_details)
assert score == 0.754
assert score == 0.75

def test_get_score_warnings(self, recwarn):
"""Test the ``get_score`` method when the metrics are raising erros for some columns."""
@@ -90,7 +90,7 @@ def test_get_score_warnings(self, recwarn):
# Assert
details = column_shape_property._details
pd.testing.assert_series_equal(details['Error'], exp_error_serie, check_names=False)
assert score == 0.702
assert score == 0.70

def test_only_categorical_columns(self):
"""Test the ``get_score`` method when there are only categorical columns."""
@@ -127,4 +127,4 @@ def test_only_categorical_columns(self):
}
expected_details = pd.DataFrame(expected_details_dict)
pd.testing.assert_frame_equal(column_shape_property._details, expected_details)
assert score == 0.881
assert score == 0.88
@@ -34,7 +34,7 @@ def test_get_score(self):
}
expected_details = pd.DataFrame(expected_details_dict)
pd.testing.assert_frame_equal(column_shape_property._details, expected_details)
assert score == 0.816
assert score == 0.82

def test_get_score_errors(self):
"""Test the ``get_score`` method when the metrics are raising errors for some columns."""
@@ -65,4 +65,4 @@ def test_get_score_errors(self):
assert column_names_nan == ['start_date', 'employability_perc']
assert error_messages[0] == expected_message_1
assert error_messages[1] == expected_message_2
assert score == 0.826
assert score == 0.83