Skip to content

Commit 849a208

Browse files
authored
Merge pull request #25 from UBC-MDS/cor_map
Added cor_map function and tests
2 parents 03937e6 + 0231a6b commit 849a208

File tree

4 files changed

+422
-5
lines changed

4 files changed

+422
-5
lines changed

eda_utils_py/eda_utils_py.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import pandas as pd
2+
import altair as alt
3+
from pandas.api.types import is_numeric_dtype
4+
15
def imputer(dataframe, strategy="mean", fill_value=None):
26
"""
37
A function to implement imputation functionality for completing missing values.
@@ -40,7 +44,8 @@ def imputer(dataframe, strategy="mean", fill_value=None):
4044
pass
4145

4246

43-
def cor_map(dataframe, num_col):
47+
def cor_map(dataframe, num_col, col_scheme="purpleorange"):
    """
    Build a correlation heatmap, annotated with the coefficients, for the
    given numeric columns of a data frame.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        The data frame to be used for EDA.
    num_col : list
        A list of string of column names with numeric data from the data frame.
    col_scheme : str, default = 'purpleorange'
        The color scheme of the heatmap desired, can only be one of the following:
            - 'purpleorange'
            - 'blueorange'
            - 'redblue'

    Returns
    -------
    altair chart
        A layered chart: a 400x400 rect heatmap titled 'Correlation Matrix'
        with the coefficients (two decimals) drawn on top of each cell.

    Raises
    ------
    TypeError
        If `dataframe` is not a pandas DataFrame, `num_col` is not a list,
        any entry of `num_col` is not a str, or `col_scheme` is not a str.
    ValueError
        If `col_scheme` is not one of the supported schemes, a name in
        `num_col` is not a column of `dataframe`, or a selected column is
        not numeric.

    Examples
    --------
    >>> import pandas as pd
    >>> data = pd.DataFrame({
    >>>     'SepalLengthCm': [5.1, 4.9, 4.7],
    >>>     'SepalWidthCm': [1.4, 1.4, 1.3],
    >>>     'PetalWidthCm': [0.2, 0.1, 0.2],
    >>> })
    >>> numerical_columns = ['SepalLengthCm', 'SepalWidthCm', 'PetalWidthCm']
    >>> cor_map(data, numerical_columns, col_scheme='purpleorange')
    """
    # --- Input validation (order matters: it decides which error is seen) ---
    if not isinstance(dataframe, pd.DataFrame):
        raise TypeError("The input dataframe must be of pd.DataFrame type")

    if not isinstance(num_col, list):
        raise TypeError("The input num_col must be of type list")

    if not all(isinstance(name, str) for name in num_col):
        raise TypeError("The type of values in num_col must all be str")

    if not isinstance(col_scheme, str):
        raise TypeError("col_scheme must be of type str")

    # ValueError subclasses Exception, so callers that caught the previous
    # bare Exception keep working while new callers can be more specific.
    if col_scheme not in ("purpleorange", "blueorange", "redblue"):
        raise ValueError("This color scheme is not available, please use either 'purpleorange', 'blueorange' or 'redblue'")

    known_columns = set(dataframe.columns)
    if any(name not in known_columns for name in num_col):
        raise ValueError("The given column names must exist in the given dataframe.")

    if not all(is_numeric_dtype(dataframe[name]) for name in num_col):
        raise ValueError("The given numerical columns must all be numeric.")

    # --- Reshape the square correlation matrix into long format ---
    # Built by hand rather than via reset_index().melt('index'): that route
    # relies on the reset column being called 'index', which is not the case
    # when a selected column is itself named 'index' or when the frame's
    # column axis carries a name.
    corr = dataframe[num_col].corr()
    row_labels = list(corr.index)
    col_labels = list(corr.columns)
    corr_matrix = pd.DataFrame(
        {
            # Column-major layout: all rows for the first column, then the next.
            "var1": row_labels * len(col_labels),
            "var2": [label for label in col_labels for _ in row_labels],
            "cor": corr.to_numpy().ravel(order="F"),
        }
    )

    # --- Heatmap layer ---
    plot = alt.Chart(corr_matrix).mark_rect().encode(
        x=alt.X("var1", title=None),
        y=alt.Y("var2", title=None),
        color=alt.Color(
            "cor",
            legend=None,
            scale=alt.Scale(scheme=col_scheme),
        ),
    ).properties(
        title="Correlation Matrix",
        width=400,
        height=400,
    )

    # --- Coefficient labels ---
    # White text on the darker cells, black elsewhere. The condition is a
    # Vega expression, so logical OR is spelled `||`.
    text = plot.mark_text(size=15).encode(
        text=alt.Text("cor", format=".2f"),
        color=alt.condition(
            "datum.cor > 0.5 || datum.cor < -0.3",
            alt.value("white"),
            alt.value("black"),
        ),
    )

    return plot + text
76143

77144

78145
def outlier_identifier(dataframe, columns=None, method="trim"):

0 commit comments

Comments
 (0)