{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Complete flow to generate a ML model for the HR Attrition dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Prepare the data:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import a custom transformer for preprocessing data based on feature definitions\n",
"from preprocessor import Preprocessor"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" sample | \n",
" variable_type | \n",
" data_type | \n",
" feature_strategy | \n",
" hash_features | \n",
"
\n",
" \n",
" name | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" Age | \n",
" 41 | \n",
" feature | \n",
" int | \n",
" scaling | \n",
" NaN | \n",
"
\n",
" \n",
" Attrition | \n",
" Yes | \n",
" target | \n",
" str | \n",
" none | \n",
" NaN | \n",
"
\n",
" \n",
" BusinessTravel | \n",
" Travel_Rarely | \n",
" feature | \n",
" str | \n",
" one hot encoding | \n",
" NaN | \n",
"
\n",
" \n",
" DailyRate | \n",
" 1102 | \n",
" feature | \n",
" int | \n",
" scaling | \n",
" NaN | \n",
"
\n",
" \n",
" Department | \n",
" Sales | \n",
" feature | \n",
" str | \n",
" one hot encoding | \n",
" NaN | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" sample variable_type data_type feature_strategy \\\n",
"name \n",
"Age 41 feature int scaling \n",
"Attrition Yes target str none \n",
"BusinessTravel Travel_Rarely feature str one hot encoding \n",
"DailyRate 1102 feature int scaling \n",
"Department Sales feature str one hot encoding \n",
"\n",
" hash_features \n",
"name \n",
"Age NaN \n",
"Attrition NaN \n",
"BusinessTravel NaN \n",
"DailyRate NaN \n",
"Department NaN "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Import feature definitions and data\n",
"sheets = pd.read_excel('Data/HR-Employee-Attrition.xlsx', sheet_name=[\"Feature Definitions\", \"Train-Test\"])\n",
"\n",
"# Create feature definitions data frame\n",
"features = sheets[\"Feature Definitions\"]\n",
"features.columns = [c.lower() for c in features.columns]\n",
"features.set_index(\"name\", append=False, inplace=True)\n",
"features.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Attrition | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Yes | \n",
"
\n",
" \n",
" 1 | \n",
" No | \n",
"
\n",
" \n",
" 2 | \n",
" No | \n",
"
\n",
" \n",
" 3 | \n",
" No | \n",
"
\n",
" \n",
" 4 | \n",
" No | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Attrition\n",
"0 Yes\n",
"1 No\n",
"2 No\n",
"3 No\n",
"4 No"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Setup the data dataframe\n",
"data = sheets[\"Train-Test\"]\n",
"\n",
"# Get the target features\n",
"target = features.loc[features[\"variable_type\"] == \"target\"]\n",
"target_name = target.index[0]\n",
"\n",
"# Get the target data\n",
"d_target = data.loc[:,[target_name]]\n",
"\n",
"d_target.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Get the features to be excluded from the model\n",
"exclusions = features['variable_type'].isin([\"excluded\", \"target\", \"identifier\"])\n",
"\n",
"excluded = features.loc[exclusions]\n",
"features = features.loc[~exclusions]\n",
"\n",
"# Remove excluded features from the data\n",
"data = data[features.index.tolist()]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Split the data into training and testing subsets\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(data, d_target, test_size=0.30, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Age | \n",
" BusinessTravel | \n",
" DailyRate | \n",
" Department | \n",
" DistanceFromHome | \n",
" Education | \n",
" EducationField | \n",
" EnvironmentSatisfaction | \n",
" Gender | \n",
" HourlyRate | \n",
" ... | \n",
" PerformanceRating | \n",
" RelationshipSatisfaction | \n",
" StockOptionLevel | \n",
" TotalWorkingYears | \n",
" TrainingTimesLastYear | \n",
" WorkLifeBalance | \n",
" YearsAtCompany | \n",
" YearsInCurrentRole | \n",
" YearsSinceLastPromotion | \n",
" YearsWithCurrManager | \n",
"
\n",
" \n",
" \n",
" \n",
" 893 | \n",
" 41 | \n",
" Travel_Frequently | \n",
" 1200 | \n",
" Research & Development | \n",
" 22 | \n",
" 3 | \n",
" Life Sciences | \n",
" 4 | \n",
" Female | \n",
" 75 | \n",
" ... | \n",
" 3 | \n",
" 1 | \n",
" 2 | \n",
" 12 | \n",
" 4 | \n",
" 2 | \n",
" 6 | \n",
" 2 | \n",
" 3 | \n",
" 3 | \n",
"
\n",
" \n",
" 115 | \n",
" 50 | \n",
" Travel_Frequently | \n",
" 809 | \n",
" Sales | \n",
" 12 | \n",
" 3 | \n",
" Marketing | \n",
" 3 | \n",
" Female | \n",
" 77 | \n",
" ... | \n",
" 3 | \n",
" 4 | \n",
" 0 | \n",
" 16 | \n",
" 3 | \n",
" 3 | \n",
" 2 | \n",
" 2 | \n",
" 2 | \n",
" 1 | \n",
"
\n",
" \n",
" 526 | \n",
" 27 | \n",
" Travel_Frequently | \n",
" 829 | \n",
" Sales | \n",
" 8 | \n",
" 1 | \n",
" Marketing | \n",
" 3 | \n",
" Male | \n",
" 84 | \n",
" ... | \n",
" 3 | \n",
" 2 | \n",
" 1 | \n",
" 5 | \n",
" 3 | \n",
" 3 | \n",
" 4 | \n",
" 2 | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" 175 | \n",
" 35 | \n",
" Travel_Frequently | \n",
" 138 | \n",
" Research & Development | \n",
" 2 | \n",
" 3 | \n",
" Medical | \n",
" 2 | \n",
" Female | \n",
" 37 | \n",
" ... | \n",
" 3 | \n",
" 4 | \n",
" 0 | \n",
" 10 | \n",
" 5 | \n",
" 3 | \n",
" 6 | \n",
" 2 | \n",
" 1 | \n",
" 2 | \n",
"
\n",
" \n",
" 63 | \n",
" 45 | \n",
" Travel_Rarely | \n",
" 193 | \n",
" Research & Development | \n",
" 6 | \n",
" 4 | \n",
" Other | \n",
" 4 | \n",
" Male | \n",
" 52 | \n",
" ... | \n",
" 3 | \n",
" 2 | \n",
" 0 | \n",
" 17 | \n",
" 3 | \n",
" 4 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
5 rows × 31 columns
\n",
"
"
],
"text/plain": [
" Age BusinessTravel DailyRate Department \\\n",
"893 41 Travel_Frequently 1200 Research & Development \n",
"115 50 Travel_Frequently 809 Sales \n",
"526 27 Travel_Frequently 829 Sales \n",
"175 35 Travel_Frequently 138 Research & Development \n",
"63 45 Travel_Rarely 193 Research & Development \n",
"\n",
" DistanceFromHome Education EducationField EnvironmentSatisfaction \\\n",
"893 22 3 Life Sciences 4 \n",
"115 12 3 Marketing 3 \n",
"526 8 1 Marketing 3 \n",
"175 2 3 Medical 2 \n",
"63 6 4 Other 4 \n",
"\n",
" Gender HourlyRate ... PerformanceRating RelationshipSatisfaction \\\n",
"893 Female 75 ... 3 1 \n",
"115 Female 77 ... 3 4 \n",
"526 Male 84 ... 3 2 \n",
"175 Female 37 ... 3 4 \n",
"63 Male 52 ... 3 2 \n",
"\n",
" StockOptionLevel TotalWorkingYears TrainingTimesLastYear \\\n",
"893 2 12 4 \n",
"115 0 16 3 \n",
"526 1 5 3 \n",
"175 0 10 5 \n",
"63 0 17 3 \n",
"\n",
" WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n",
"893 2 6 2 \n",
"115 3 2 2 \n",
"526 3 4 2 \n",
"175 3 6 2 \n",
"63 4 0 0 \n",
"\n",
" YearsSinceLastPromotion YearsWithCurrManager \n",
"893 3 3 \n",
"115 2 1 \n",
"526 1 1 \n",
"175 1 2 \n",
"63 0 0 \n",
"\n",
"[5 rows x 31 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_test.head()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" BusinessTravel_Non-Travel | \n",
" BusinessTravel_Travel_Frequently | \n",
" BusinessTravel_Travel_Rarely | \n",
" Department_Human Resources | \n",
" Department_Research & Development | \n",
" Department_Sales | \n",
" Education_1 | \n",
" Education_2 | \n",
" Education_3 | \n",
" Education_4 | \n",
" ... | \n",
" TotalWorkingYears | \n",
" TrainingTimesLastYear | \n",
" YearsAtCompany | \n",
" YearsInCurrentRole | \n",
" YearsSinceLastPromotion | \n",
" YearsWithCurrManager | \n",
" JobRole0 | \n",
" JobRole1 | \n",
" JobRole2 | \n",
" JobRole3 | \n",
"
\n",
" \n",
" \n",
" \n",
" 893 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" ... | \n",
" 0.083789 | \n",
" 0.927695 | \n",
" -0.158869 | \n",
" -0.613794 | \n",
" 0.232242 | \n",
" -0.328203 | \n",
" -0.994429 | \n",
" -0.451462 | \n",
" -0.816773 | \n",
" -1.706519 | \n",
"
\n",
" \n",
" 115 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" ... | \n",
" 0.591154 | \n",
" 0.153090 | \n",
" -0.807584 | \n",
" -0.613794 | \n",
" -0.071434 | \n",
" -0.882752 | \n",
" 1.265490 | \n",
" -0.034487 | \n",
" -0.816773 | \n",
" 0.899445 | \n",
"
\n",
" \n",
" 526 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" ... | \n",
" -0.804100 | \n",
" 0.153090 | \n",
" -0.483227 | \n",
" -0.613794 | \n",
" -0.375111 | \n",
" -0.882752 | \n",
" 1.265490 | \n",
" -0.034487 | \n",
" -0.816773 | \n",
" 0.899445 | \n",
"
\n",
" \n",
" 175 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" ... | \n",
" -0.169894 | \n",
" 1.702300 | \n",
" -0.158869 | \n",
" -0.613794 | \n",
" -0.375111 | \n",
" -0.605477 | \n",
" -0.994429 | \n",
" -1.702387 | \n",
" 1.808367 | \n",
" 0.899445 | \n",
"
\n",
" \n",
" 63 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" ... | \n",
" 0.717995 | \n",
" 0.153090 | \n",
" -1.131942 | \n",
" -1.175408 | \n",
" -0.678787 | \n",
" -1.160026 | \n",
" 0.512184 | \n",
" 0.799463 | \n",
" 0.058274 | \n",
" 0.899445 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 381 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" ... | \n",
" -1.311464 | \n",
" 0.153090 | \n",
" -0.969763 | \n",
" -1.175408 | \n",
" -0.678787 | \n",
" -1.160026 | \n",
" 2.018796 | \n",
" 1.633413 | \n",
" -0.816773 | \n",
" 0.030790 | \n",
"
\n",
" \n",
" 292 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" ... | \n",
" 0.210630 | \n",
" 0.153090 | \n",
" 0.003310 | \n",
" 0.790241 | \n",
" -0.375111 | \n",
" 0.780896 | \n",
" 1.265490 | \n",
" -0.034487 | \n",
" -0.816773 | \n",
" 0.899445 | \n",
"
\n",
" \n",
" 1083 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" ... | \n",
" -1.184623 | \n",
" 0.153090 | \n",
" -0.807584 | \n",
" -0.613794 | \n",
" -0.678787 | \n",
" -0.605477 | \n",
" -0.241123 | \n",
" 1.216438 | \n",
" -0.816773 | \n",
" -0.837864 | \n",
"
\n",
" \n",
" 551 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" ... | \n",
" -1.311464 | \n",
" 0.153090 | \n",
" -0.969763 | \n",
" -1.175408 | \n",
" -0.678787 | \n",
" -1.160026 | \n",
" -0.994429 | \n",
" -0.451462 | \n",
" -0.816773 | \n",
" -1.706519 | \n",
"
\n",
" \n",
" 141 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" ... | \n",
" -0.677258 | \n",
" 0.153090 | \n",
" -0.483227 | \n",
" -0.332987 | \n",
" -0.375111 | \n",
" -0.605477 | \n",
" -0.241123 | \n",
" 1.216438 | \n",
" -0.816773 | \n",
" -0.837864 | \n",
"
\n",
" \n",
"
\n",
"
399 rows × 74 columns
\n",
"
"
],
"text/plain": [
" BusinessTravel_Non-Travel BusinessTravel_Travel_Frequently \\\n",
"893 0 1 \n",
"115 0 1 \n",
"526 0 1 \n",
"175 0 1 \n",
"63 0 0 \n",
"... ... ... \n",
"381 0 0 \n",
"292 0 0 \n",
"1083 0 0 \n",
"551 0 0 \n",
"141 0 1 \n",
"\n",
" BusinessTravel_Travel_Rarely Department_Human Resources \\\n",
"893 0 0 \n",
"115 0 0 \n",
"526 0 0 \n",
"175 0 0 \n",
"63 1 0 \n",
"... ... ... \n",
"381 1 1 \n",
"292 1 0 \n",
"1083 1 0 \n",
"551 1 0 \n",
"141 0 0 \n",
"\n",
" Department_Research & Development Department_Sales Education_1 \\\n",
"893 1 0 0 \n",
"115 0 1 0 \n",
"526 0 1 1 \n",
"175 1 0 0 \n",
"63 1 0 0 \n",
"... ... ... ... \n",
"381 0 0 0 \n",
"292 0 1 0 \n",
"1083 0 1 1 \n",
"551 1 0 0 \n",
"141 0 1 0 \n",
"\n",
" Education_2 Education_3 Education_4 ... TotalWorkingYears \\\n",
"893 0 1 0 ... 0.083789 \n",
"115 0 1 0 ... 0.591154 \n",
"526 0 0 0 ... -0.804100 \n",
"175 0 1 0 ... -0.169894 \n",
"63 0 0 1 ... 0.717995 \n",
"... ... ... ... ... ... \n",
"381 1 0 0 ... -1.311464 \n",
"292 0 1 0 ... 0.210630 \n",
"1083 0 0 0 ... -1.184623 \n",
"551 0 1 0 ... -1.311464 \n",
"141 0 0 1 ... -0.677258 \n",
"\n",
" TrainingTimesLastYear YearsAtCompany YearsInCurrentRole \\\n",
"893 0.927695 -0.158869 -0.613794 \n",
"115 0.153090 -0.807584 -0.613794 \n",
"526 0.153090 -0.483227 -0.613794 \n",
"175 1.702300 -0.158869 -0.613794 \n",
"63 0.153090 -1.131942 -1.175408 \n",
"... ... ... ... \n",
"381 0.153090 -0.969763 -1.175408 \n",
"292 0.153090 0.003310 0.790241 \n",
"1083 0.153090 -0.807584 -0.613794 \n",
"551 0.153090 -0.969763 -1.175408 \n",
"141 0.153090 -0.483227 -0.332987 \n",
"\n",
" YearsSinceLastPromotion YearsWithCurrManager JobRole0 JobRole1 \\\n",
"893 0.232242 -0.328203 -0.994429 -0.451462 \n",
"115 -0.071434 -0.882752 1.265490 -0.034487 \n",
"526 -0.375111 -0.882752 1.265490 -0.034487 \n",
"175 -0.375111 -0.605477 -0.994429 -1.702387 \n",
"63 -0.678787 -1.160026 0.512184 0.799463 \n",
"... ... ... ... ... \n",
"381 -0.678787 -1.160026 2.018796 1.633413 \n",
"292 -0.375111 0.780896 1.265490 -0.034487 \n",
"1083 -0.678787 -0.605477 -0.241123 1.216438 \n",
"551 -0.678787 -1.160026 -0.994429 -0.451462 \n",
"141 -0.375111 -0.605477 -0.241123 1.216438 \n",
"\n",
" JobRole2 JobRole3 \n",
"893 -0.816773 -1.706519 \n",
"115 -0.816773 0.899445 \n",
"526 -0.816773 0.899445 \n",
"175 1.808367 0.899445 \n",
"63 0.058274 0.899445 \n",
"... ... ... \n",
"381 -0.816773 0.030790 \n",
"292 -0.816773 0.899445 \n",
"1083 -0.816773 -0.837864 \n",
"551 -0.816773 -1.706519 \n",
"141 -0.816773 -0.837864 \n",
"\n",
"[399 rows x 74 columns]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test the preprocessor\n",
"prep = Preprocessor(features, return_type='df').fit(X_train)\n",
"prep.transform(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set up machine learning pipelines, fit and score the models:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Logistic Regression pipeline test accuracy: 0.905\n",
"Random Forest pipeline test accuracy: 0.862\n",
"Classifier with best accuracy: Logistic Regression\n",
"Saved Logistic Regression pipeline to file\n"
]
}
],
"source": [
"# Set up a pipeline to run the ML flow\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"import pickle\n",
"\n",
"# Construct the pipelines\n",
"pipe_lr = Pipeline([('prep', Preprocessor(features, return_type='df')), ('clf', LogisticRegression(solver='lbfgs', random_state=42))])\n",
"pipe_rf = Pipeline([('prep', Preprocessor(features, return_type='df')), ('clf', RandomForestClassifier(n_estimators=10, random_state=42))])\n",
"\n",
"# List of pipelines for ease of iteration\n",
"pipelines = [pipe_lr, pipe_rf]\n",
"\n",
"# Dictionary of pipelines and classifier types for ease of reference\n",
"pipe_dict = {0: 'Logistic Regression', 1: 'Random Forest'}\n",
"\n",
"# Fit the pipelines\n",
"for pipe in pipelines:\n",
" pipe.fit(X_train, y_train.values.ravel())\n",
"\n",
"# Compare accuracies\n",
"for idx, val in enumerate(pipelines):\n",
" print('%s pipeline test accuracy: %.3f' % (pipe_dict[idx], val.score(X_test, y_test)))\n",
"\n",
"# Identify the most accurate model on test data\n",
"best_acc = 0.0\n",
"best_clf = 0\n",
"best_pipe = ''\n",
"for idx, val in enumerate(pipelines):\n",
" if val.score(X_test, y_test) > best_acc:\n",
" best_acc = val.score(X_test, y_test.values.ravel())\n",
" best_pipe = val\n",
" best_clf = idx\n",
"print('Classifier with best accuracy: %s' % pipe_dict[best_clf])\n",
"\n",
"# Save pipeline to file\n",
"with open('HR-Attrition-v1.pkl', 'wb') as file:\n",
" pickle.dump(best_pipe, file)\n",
" print('Saved %s pipeline to file' % pipe_dict[best_clf])\n",
"\n",
"# Also save the preprocessor and model as separate files\n",
"with open('HR-Attrition-v1-prep.pkl', 'wb') as file:\n",
" pickle.dump(best_pipe.named_steps['prep'], file)\n",
"with open('HR-Attrition-v1-clf.pkl', 'wb') as file:\n",
" pickle.dump(best_pipe.named_steps['clf'], file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Validate the saved model:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(memory=None,\n",
" steps=[('prep',\n",
" ),\n",
" ('clf',\n",
" LogisticRegression(C=1.0, class_weight=None, dual=False,\n",
" fit_intercept=True, intercept_scaling=1,\n",
" l1_ratio=None, max_iter=100,\n",
" multi_class='warn', n_jobs=None,\n",
" penalty='l2', random_state=42,\n",
" solver='lbfgs', tol=0.0001, verbose=0,\n",
" warm_start=False))],\n",
" verbose=False)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the saved pipeline from disk\n",
"with open('HR-Attrition-v1.pkl', 'rb') as file:\n",
" model = pickle.load(file)\n",
"\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Age | \n",
" Attrition | \n",
" BusinessTravel | \n",
" DailyRate | \n",
" Department | \n",
" DistanceFromHome | \n",
" Education | \n",
" EducationField | \n",
" EmployeeCount | \n",
" EmployeeNumber | \n",
" ... | \n",
" RelationshipSatisfaction | \n",
" StandardHours | \n",
" StockOptionLevel | \n",
" TotalWorkingYears | \n",
" TrainingTimesLastYear | \n",
" WorkLifeBalance | \n",
" YearsAtCompany | \n",
" YearsInCurrentRole | \n",
" YearsSinceLastPromotion | \n",
" YearsWithCurrManager | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 37 | \n",
" Yes | \n",
" Travel_Rarely | \n",
" 1373 | \n",
" Research & Development | \n",
" 2 | \n",
" 2 | \n",
" Other | \n",
" 1 | \n",
" 4 | \n",
" ... | \n",
" 2 | \n",
" 80 | \n",
" 0 | \n",
" 7 | \n",
" 3 | \n",
" 3 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" 27 | \n",
" No | \n",
" Travel_Rarely | \n",
" 591 | \n",
" Research & Development | \n",
" 2 | \n",
" 1 | \n",
" Medical | \n",
" 1 | \n",
" 7 | \n",
" ... | \n",
" 4 | \n",
" 80 | \n",
" 1 | \n",
" 6 | \n",
" 3 | \n",
" 3 | \n",
" 2 | \n",
" 2 | \n",
" 2 | \n",
" 2 | \n",
"
\n",
" \n",
" 2 | \n",
" 32 | \n",
" No | \n",
" Travel_Frequently | \n",
" 1005 | \n",
" Research & Development | \n",
" 2 | \n",
" 2 | \n",
" Life Sciences | \n",
" 1 | \n",
" 8 | \n",
" ... | \n",
" 3 | \n",
" 80 | \n",
" 0 | \n",
" 8 | \n",
" 2 | \n",
" 2 | \n",
" 7 | \n",
" 7 | \n",
" 3 | \n",
" 6 | \n",
"
\n",
" \n",
" 3 | \n",
" 53 | \n",
" No | \n",
" Travel_Rarely | \n",
" 1282 | \n",
" Research & Development | \n",
" 5 | \n",
" 3 | \n",
" Other | \n",
" 1 | \n",
" 32 | \n",
" ... | \n",
" 4 | \n",
" 80 | \n",
" 1 | \n",
" 26 | \n",
" 3 | \n",
" 2 | \n",
" 14 | \n",
" 13 | \n",
" 4 | \n",
" 8 | \n",
"
\n",
" \n",
" 4 | \n",
" 43 | \n",
" No | \n",
" Travel_Rarely | \n",
" 1273 | \n",
" Research & Development | \n",
" 2 | \n",
" 2 | \n",
" Medical | \n",
" 1 | \n",
" 46 | \n",
" ... | \n",
" 4 | \n",
" 80 | \n",
" 2 | \n",
" 6 | \n",
" 3 | \n",
" 2 | \n",
" 5 | \n",
" 3 | \n",
" 1 | \n",
" 4 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 135 | \n",
" 34 | \n",
" No | \n",
" Travel_Rarely | \n",
" 704 | \n",
" Sales | \n",
" 28 | \n",
" 3 | \n",
" Marketing | \n",
" 1 | \n",
" 2035 | \n",
" ... | \n",
" 4 | \n",
" 80 | \n",
" 2 | \n",
" 8 | \n",
" 2 | \n",
" 3 | \n",
" 8 | \n",
" 7 | \n",
" 1 | \n",
" 7 | \n",
"
\n",
" \n",
" 136 | \n",
" 36 | \n",
" No | \n",
" Travel_Rarely | \n",
" 1120 | \n",
" Sales | \n",
" 11 | \n",
" 4 | \n",
" Marketing | \n",
" 1 | \n",
" 2045 | \n",
" ... | \n",
" 1 | \n",
" 80 | \n",
" 1 | \n",
" 8 | \n",
" 2 | \n",
" 2 | \n",
" 6 | \n",
" 3 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 137 | \n",
" 29 | \n",
" No | \n",
" Travel_Rarely | \n",
" 468 | \n",
" Research & Development | \n",
" 28 | \n",
" 4 | \n",
" Medical | \n",
" 1 | \n",
" 2054 | \n",
" ... | \n",
" 2 | \n",
" 80 | \n",
" 0 | \n",
" 5 | \n",
" 3 | \n",
" 1 | \n",
" 5 | \n",
" 4 | \n",
" 0 | \n",
" 4 | \n",
"
\n",
" \n",
" 138 | \n",
" 39 | \n",
" No | \n",
" Travel_Rarely | \n",
" 722 | \n",
" Sales | \n",
" 24 | \n",
" 1 | \n",
" Marketing | \n",
" 1 | \n",
" 2056 | \n",
" ... | \n",
" 1 | \n",
" 80 | \n",
" 1 | \n",
" 21 | \n",
" 2 | \n",
" 2 | \n",
" 20 | \n",
" 9 | \n",
" 9 | \n",
" 6 | \n",
"
\n",
" \n",
" 139 | \n",
" 36 | \n",
" No | \n",
" Travel_Frequently | \n",
" 884 | \n",
" Research & Development | \n",
" 23 | \n",
" 2 | \n",
" Medical | \n",
" 1 | \n",
" 2061 | \n",
" ... | \n",
" 3 | \n",
" 80 | \n",
" 1 | \n",
" 17 | \n",
" 3 | \n",
" 3 | \n",
" 5 | \n",
" 2 | \n",
" 0 | \n",
" 3 | \n",
"
\n",
" \n",
"
\n",
"
140 rows × 35 columns
\n",
"
"
],
"text/plain": [
" Age Attrition BusinessTravel DailyRate Department \\\n",
"0 37 Yes Travel_Rarely 1373 Research & Development \n",
"1 27 No Travel_Rarely 591 Research & Development \n",
"2 32 No Travel_Frequently 1005 Research & Development \n",
"3 53 No Travel_Rarely 1282 Research & Development \n",
"4 43 No Travel_Rarely 1273 Research & Development \n",
".. ... ... ... ... ... \n",
"135 34 No Travel_Rarely 704 Sales \n",
"136 36 No Travel_Rarely 1120 Sales \n",
"137 29 No Travel_Rarely 468 Research & Development \n",
"138 39 No Travel_Rarely 722 Sales \n",
"139 36 No Travel_Frequently 884 Research & Development \n",
"\n",
" DistanceFromHome Education EducationField EmployeeCount \\\n",
"0 2 2 Other 1 \n",
"1 2 1 Medical 1 \n",
"2 2 2 Life Sciences 1 \n",
"3 5 3 Other 1 \n",
"4 2 2 Medical 1 \n",
".. ... ... ... ... \n",
"135 28 3 Marketing 1 \n",
"136 11 4 Marketing 1 \n",
"137 28 4 Medical 1 \n",
"138 24 1 Marketing 1 \n",
"139 23 2 Medical 1 \n",
"\n",
" EmployeeNumber ... RelationshipSatisfaction StandardHours \\\n",
"0 4 ... 2 80 \n",
"1 7 ... 4 80 \n",
"2 8 ... 3 80 \n",
"3 32 ... 4 80 \n",
"4 46 ... 4 80 \n",
".. ... ... ... ... \n",
"135 2035 ... 4 80 \n",
"136 2045 ... 1 80 \n",
"137 2054 ... 2 80 \n",
"138 2056 ... 1 80 \n",
"139 2061 ... 3 80 \n",
"\n",
" StockOptionLevel TotalWorkingYears TrainingTimesLastYear \\\n",
"0 0 7 3 \n",
"1 1 6 3 \n",
"2 0 8 2 \n",
"3 1 26 3 \n",
"4 2 6 3 \n",
".. ... ... ... \n",
"135 2 8 2 \n",
"136 1 8 2 \n",
"137 0 5 3 \n",
"138 1 21 2 \n",
"139 1 17 3 \n",
"\n",
" WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n",
"0 3 0 0 \n",
"1 3 2 2 \n",
"2 2 7 7 \n",
"3 2 14 13 \n",
"4 2 5 3 \n",
".. ... ... ... \n",
"135 3 8 7 \n",
"136 2 6 3 \n",
"137 1 5 4 \n",
"138 2 20 9 \n",
"139 3 5 2 \n",
"\n",
" YearsSinceLastPromotion YearsWithCurrManager \n",
"0 0 0 \n",
"1 2 2 \n",
"2 3 6 \n",
"3 4 8 \n",
"4 1 4 \n",
".. ... ... \n",
"135 1 7 \n",
"136 0 0 \n",
"137 0 4 \n",
"138 9 6 \n",
"139 0 3 \n",
"\n",
"[140 rows x 35 columns]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load additional data to test the saved model\n",
"validation = pd.read_excel('Data/HR-Employee-Attrition.xlsx', sheet_name=\"Validate\")\n",
"validation"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Attrition | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Yes | \n",
"
\n",
" \n",
" 1 | \n",
" No | \n",
"
\n",
" \n",
" 2 | \n",
" No | \n",
"
\n",
" \n",
" 3 | \n",
" No | \n",
"
\n",
" \n",
" 4 | \n",
" No | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Attrition\n",
"0 Yes\n",
"1 No\n",
"2 No\n",
"3 No\n",
"4 No"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Get the targets\n",
"v_target = validation.loc[:,[target_name]]\n",
"v_target.head()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Age | \n",
" BusinessTravel | \n",
" DailyRate | \n",
" Department | \n",
" DistanceFromHome | \n",
" Education | \n",
" EducationField | \n",
" EnvironmentSatisfaction | \n",
" Gender | \n",
" HourlyRate | \n",
" ... | \n",
" PerformanceRating | \n",
" RelationshipSatisfaction | \n",
" StockOptionLevel | \n",
" TotalWorkingYears | \n",
" TrainingTimesLastYear | \n",
" WorkLifeBalance | \n",
" YearsAtCompany | \n",
" YearsInCurrentRole | \n",
" YearsSinceLastPromotion | \n",
" YearsWithCurrManager | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 37 | \n",
" Travel_Rarely | \n",
" 1373 | \n",
" Research & Development | \n",
" 2 | \n",
" 2 | \n",
" Other | \n",
" 4 | \n",
" Male | \n",
" 92 | \n",
" ... | \n",
" 3 | \n",
" 2 | \n",
" 0 | \n",
" 7 | \n",
" 3 | \n",
" 3 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" 27 | \n",
" Travel_Rarely | \n",
" 591 | \n",
" Research & Development | \n",
" 2 | \n",
" 1 | \n",
" Medical | \n",
" 1 | \n",
" Male | \n",
" 40 | \n",
" ... | \n",
" 3 | \n",
" 4 | \n",
" 1 | \n",
" 6 | \n",
" 3 | \n",
" 3 | \n",
" 2 | \n",
" 2 | \n",
" 2 | \n",
" 2 | \n",
"
\n",
" \n",
" 2 | \n",
" 32 | \n",
" Travel_Frequently | \n",
" 1005 | \n",
" Research & Development | \n",
" 2 | \n",
" 2 | \n",
" Life Sciences | \n",
" 4 | \n",
" Male | \n",
" 79 | \n",
" ... | \n",
" 3 | \n",
" 3 | \n",
" 0 | \n",
" 8 | \n",
" 2 | \n",
" 2 | \n",
" 7 | \n",
" 7 | \n",
" 3 | \n",
" 6 | \n",
"
\n",
" \n",
" 3 | \n",
" 53 | \n",
" Travel_Rarely | \n",
" 1282 | \n",
" Research & Development | \n",
" 5 | \n",
" 3 | \n",
" Other | \n",
" 3 | \n",
" Female | \n",
" 58 | \n",
" ... | \n",
" 3 | \n",
" 4 | \n",
" 1 | \n",
" 26 | \n",
" 3 | \n",
" 2 | \n",
" 14 | \n",
" 13 | \n",
" 4 | \n",
" 8 | \n",
"
\n",
" \n",
" 4 | \n",
" 43 | \n",
" Travel_Rarely | \n",
" 1273 | \n",
" Research & Development | \n",
" 2 | \n",
" 2 | \n",
" Medical | \n",
" 4 | \n",
" Female | \n",
" 72 | \n",
" ... | \n",
" 3 | \n",
" 4 | \n",
" 2 | \n",
" 6 | \n",
" 3 | \n",
" 2 | \n",
" 5 | \n",
" 3 | \n",
" 1 | \n",
" 4 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 135 | \n",
" 34 | \n",
" Travel_Rarely | \n",
" 704 | \n",
" Sales | \n",
" 28 | \n",
" 3 | \n",
" Marketing | \n",
" 4 | \n",
" Female | \n",
" 95 | \n",
" ... | \n",
" 4 | \n",
" 4 | \n",
" 2 | \n",
" 8 | \n",
" 2 | \n",
" 3 | \n",
" 8 | \n",
" 7 | \n",
" 1 | \n",
" 7 | \n",
"
\n",
" \n",
" 136 | \n",
" 36 | \n",
" Travel_Rarely | \n",
" 1120 | \n",
" Sales | \n",
" 11 | \n",
" 4 | \n",
" Marketing | \n",
" 2 | \n",
" Female | \n",
" 100 | \n",
" ... | \n",
" 3 | \n",
" 1 | \n",
" 1 | \n",
" 8 | \n",
" 2 | \n",
" 2 | \n",
" 6 | \n",
" 3 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 137 | \n",
" 29 | \n",
" Travel_Rarely | \n",
" 468 | \n",
" Research & Development | \n",
" 28 | \n",
" 4 | \n",
" Medical | \n",
" 4 | \n",
" Female | \n",
" 73 | \n",
" ... | \n",
" 3 | \n",
" 2 | \n",
" 0 | \n",
" 5 | \n",
" 3 | \n",
" 1 | \n",
" 5 | \n",
" 4 | \n",
" 0 | \n",
" 4 | \n",
"
\n",
" \n",
" 138 | \n",
" 39 | \n",
" Travel_Rarely | \n",
" 722 | \n",
" Sales | \n",
" 24 | \n",
" 1 | \n",
" Marketing | \n",
" 2 | \n",
" Female | \n",
" 60 | \n",
" ... | \n",
" 3 | \n",
" 1 | \n",
" 1 | \n",
" 21 | \n",
" 2 | \n",
" 2 | \n",
" 20 | \n",
" 9 | \n",
" 9 | \n",
" 6 | \n",
"
\n",
" \n",
" 139 | \n",
" 36 | \n",
" Travel_Frequently | \n",
" 884 | \n",
" Research & Development | \n",
" 23 | \n",
" 2 | \n",
" Medical | \n",
" 3 | \n",
" Male | \n",
" 41 | \n",
" ... | \n",
" 3 | \n",
" 3 | \n",
" 1 | \n",
" 17 | \n",
" 3 | \n",
" 3 | \n",
" 5 | \n",
" 2 | \n",
" 0 | \n",
" 3 | \n",
"
\n",
" \n",
"
\n",
"
140 rows × 31 columns
\n",
"
"
],
"text/plain": [
" Age BusinessTravel DailyRate Department \\\n",
"0 37 Travel_Rarely 1373 Research & Development \n",
"1 27 Travel_Rarely 591 Research & Development \n",
"2 32 Travel_Frequently 1005 Research & Development \n",
"3 53 Travel_Rarely 1282 Research & Development \n",
"4 43 Travel_Rarely 1273 Research & Development \n",
".. ... ... ... ... \n",
"135 34 Travel_Rarely 704 Sales \n",
"136 36 Travel_Rarely 1120 Sales \n",
"137 29 Travel_Rarely 468 Research & Development \n",
"138 39 Travel_Rarely 722 Sales \n",
"139 36 Travel_Frequently 884 Research & Development \n",
"\n",
" DistanceFromHome Education EducationField EnvironmentSatisfaction \\\n",
"0 2 2 Other 4 \n",
"1 2 1 Medical 1 \n",
"2 2 2 Life Sciences 4 \n",
"3 5 3 Other 3 \n",
"4 2 2 Medical 4 \n",
".. ... ... ... ... \n",
"135 28 3 Marketing 4 \n",
"136 11 4 Marketing 2 \n",
"137 28 4 Medical 4 \n",
"138 24 1 Marketing 2 \n",
"139 23 2 Medical 3 \n",
"\n",
" Gender HourlyRate ... PerformanceRating RelationshipSatisfaction \\\n",
"0 Male 92 ... 3 2 \n",
"1 Male 40 ... 3 4 \n",
"2 Male 79 ... 3 3 \n",
"3 Female 58 ... 3 4 \n",
"4 Female 72 ... 3 4 \n",
".. ... ... ... ... ... \n",
"135 Female 95 ... 4 4 \n",
"136 Female 100 ... 3 1 \n",
"137 Female 73 ... 3 2 \n",
"138 Female 60 ... 3 1 \n",
"139 Male 41 ... 3 3 \n",
"\n",
" StockOptionLevel TotalWorkingYears TrainingTimesLastYear \\\n",
"0 0 7 3 \n",
"1 1 6 3 \n",
"2 0 8 2 \n",
"3 1 26 3 \n",
"4 2 6 3 \n",
".. ... ... ... \n",
"135 2 8 2 \n",
"136 1 8 2 \n",
"137 0 5 3 \n",
"138 1 21 2 \n",
"139 1 17 3 \n",
"\n",
" WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n",
"0 3 0 0 \n",
"1 3 2 2 \n",
"2 2 7 7 \n",
"3 2 14 13 \n",
"4 2 5 3 \n",
".. ... ... ... \n",
"135 3 8 7 \n",
"136 2 6 3 \n",
"137 1 5 4 \n",
"138 2 20 9 \n",
"139 3 5 2 \n",
"\n",
" YearsSinceLastPromotion YearsWithCurrManager \n",
"0 0 0 \n",
"1 2 2 \n",
"2 3 6 \n",
"3 4 8 \n",
"4 1 4 \n",
".. ... ... \n",
"135 1 7 \n",
"136 0 0 \n",
"137 0 4 \n",
"138 9 6 \n",
"139 0 3 \n",
"\n",
"[140 rows x 31 columns]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Remove excluded features from the validation dataset\n",
"validation = validation[features.index.tolist()]\n",
"validation"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8785714285714286"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Get a score for the validation dataset from the saved pipeline\n",
"model.score(validation, v_target.values.ravel())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Train a Keras model:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" BusinessTravel_Non-Travel | \n",
" BusinessTravel_Travel_Frequently | \n",
" BusinessTravel_Travel_Rarely | \n",
" Department_Human Resources | \n",
" Department_Research & Development | \n",
" Department_Sales | \n",
" Education_1 | \n",
" Education_2 | \n",
" Education_3 | \n",
" Education_4 | \n",
" ... | \n",
" TotalWorkingYears | \n",
" TrainingTimesLastYear | \n",
" YearsAtCompany | \n",
" YearsInCurrentRole | \n",
" YearsSinceLastPromotion | \n",
" YearsWithCurrManager | \n",
" JobRole0 | \n",
" JobRole1 | \n",
" JobRole2 | \n",
" JobRole3 | \n",
"
\n",
" \n",
" \n",
" \n",
" 1298 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" ... | \n",
" 0.591154 | \n",
" -0.621514 | \n",
" 1.462919 | \n",
" 1.913468 | \n",
" 1.143272 | \n",
" 1.058170 | \n",
" 0.512184 | \n",
" 0.799463 | \n",
" 0.058274 | \n",
" 0.899445 | \n",
"
\n",
" \n",
" 620 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" ... | \n",
" 0.210630 | \n",
" -0.621514 | \n",
" -1.131942 | \n",
" -1.175408 | \n",
" -0.678787 | \n",
" -1.160026 | \n",
" -0.994429 | \n",
" -0.451462 | \n",
" -0.816773 | \n",
" -1.706519 | \n",
"
\n",
" \n",
" 1193 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" ... | \n",
" 0.464312 | \n",
" -0.621514 | \n",
" -0.969763 | \n",
" -1.175408 | \n",
" -0.678787 | \n",
" -1.160026 | \n",
" -0.994429 | \n",
" -0.451462 | \n",
" -0.816773 | \n",
" -1.706519 | \n",
"
\n",
" \n",
" 139 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" ... | \n",
" -0.296735 | \n",
" 0.153090 | \n",
" -0.969763 | \n",
" -1.175408 | \n",
" -0.678787 | \n",
" -1.160026 | \n",
" -0.994429 | \n",
" -0.451462 | \n",
" -0.816773 | \n",
" -1.706519 | \n",
"
\n",
" \n",
" 1165 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" ... | \n",
" -0.296735 | \n",
" 0.153090 | \n",
" -0.483227 | \n",
" -0.332987 | \n",
" -0.071434 | \n",
" -0.605477 | \n",
" 0.512184 | \n",
" 0.799463 | \n",
" 0.058274 | \n",
" 0.899445 | \n",
"
\n",
" \n",
"
\n",
"
5 rows × 74 columns
\n",
"
"
],
"text/plain": [
" BusinessTravel_Non-Travel BusinessTravel_Travel_Frequently \\\n",
"1298 0 0 \n",
"620 0 0 \n",
"1193 0 0 \n",
"139 0 0 \n",
"1165 0 0 \n",
"\n",
" BusinessTravel_Travel_Rarely Department_Human Resources \\\n",
"1298 1 0 \n",
"620 1 0 \n",
"1193 1 0 \n",
"139 1 0 \n",
"1165 1 0 \n",
"\n",
" Department_Research & Development Department_Sales Education_1 \\\n",
"1298 1 0 0 \n",
"620 1 0 0 \n",
"1193 1 0 0 \n",
"139 1 0 0 \n",
"1165 1 0 1 \n",
"\n",
" Education_2 Education_3 Education_4 ... TotalWorkingYears \\\n",
"1298 0 1 0 ... 0.591154 \n",
"620 0 0 1 ... 0.210630 \n",
"1193 0 0 1 ... 0.464312 \n",
"139 0 1 0 ... -0.296735 \n",
"1165 0 0 0 ... -0.296735 \n",
"\n",
" TrainingTimesLastYear YearsAtCompany YearsInCurrentRole \\\n",
"1298 -0.621514 1.462919 1.913468 \n",
"620 -0.621514 -1.131942 -1.175408 \n",
"1193 -0.621514 -0.969763 -1.175408 \n",
"139 0.153090 -0.969763 -1.175408 \n",
"1165 0.153090 -0.483227 -0.332987 \n",
"\n",
" YearsSinceLastPromotion YearsWithCurrManager JobRole0 JobRole1 \\\n",
"1298 1.143272 1.058170 0.512184 0.799463 \n",
"620 -0.678787 -1.160026 -0.994429 -0.451462 \n",
"1193 -0.678787 -1.160026 -0.994429 -0.451462 \n",
"139 -0.678787 -1.160026 -0.994429 -0.451462 \n",
"1165 -0.071434 -0.605477 0.512184 0.799463 \n",
"\n",
" JobRole2 JobRole3 \n",
"1298 0.058274 0.899445 \n",
"620 -0.816773 -1.706519 \n",
"1193 -0.816773 -1.706519 \n",
"139 -0.816773 -1.706519 \n",
"1165 0.058274 0.899445 \n",
"\n",
"[5 rows x 74 columns]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run the training data through the preprocessor\n",
"X_train_transformed = best_pipe.named_steps['prep'].transform(X_train)\n",
"X_train_transformed.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0])"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# Encode target values\n",
"le = LabelEncoder().fit(y_train.values.ravel())\n",
"y_train_encoded = le.transform(y_train.values.ravel())\n",
"y_train_encoded[:10]"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/50\n",
"931/931 [==============================] - 0s 498us/step - loss: 0.2025 - accuracy: 0.2266\n",
"Epoch 2/50\n",
"931/931 [==============================] - 0s 165us/step - loss: 0.1567 - accuracy: 0.3373\n",
"Epoch 3/50\n",
"931/931 [==============================] - 0s 178us/step - loss: 0.1286 - accuracy: 0.5800\n",
"Epoch 4/50\n",
"931/931 [==============================] - 0s 165us/step - loss: 0.1126 - accuracy: 0.6273\n",
"Epoch 5/50\n",
"931/931 [==============================] - 0s 167us/step - loss: 0.0951 - accuracy: 0.7132\n",
"Epoch 6/50\n",
"931/931 [==============================] - 0s 164us/step - loss: 0.0796 - accuracy: 0.7444\n",
"Epoch 7/50\n",
"931/931 [==============================] - 0s 155us/step - loss: 0.0626 - accuracy: 0.8163\n",
"Epoch 8/50\n",
"931/931 [==============================] - 0s 187us/step - loss: 0.0501 - accuracy: 0.8528\n",
"Epoch 9/50\n",
"931/931 [==============================] - 0s 202us/step - loss: 0.0432 - accuracy: 0.8937\n",
"Epoch 10/50\n",
"931/931 [==============================] - 0s 216us/step - loss: 0.0341 - accuracy: 0.9151\n",
"Epoch 11/50\n",
"931/931 [==============================] - 0s 209us/step - loss: 0.0302 - accuracy: 0.9323\n",
"Epoch 12/50\n",
"931/931 [==============================] - 0s 201us/step - loss: 0.0207 - accuracy: 0.9538\n",
"Epoch 13/50\n",
"931/931 [==============================] - 0s 194us/step - loss: 0.0166 - accuracy: 0.9646\n",
"Epoch 14/50\n",
"931/931 [==============================] - 0s 215us/step - loss: 0.0134 - accuracy: 0.9678\n",
"Epoch 15/50\n",
"931/931 [==============================] - 0s 178us/step - loss: 0.0094 - accuracy: 0.9817\n",
"Epoch 16/50\n",
"931/931 [==============================] - 0s 159us/step - loss: 0.0072 - accuracy: 0.9903\n",
"Epoch 17/50\n",
"931/931 [==============================] - 0s 162us/step - loss: 0.0052 - accuracy: 0.9936\n",
"Epoch 18/50\n",
"931/931 [==============================] - 0s 162us/step - loss: 0.0041 - accuracy: 0.9946\n",
"Epoch 19/50\n",
"931/931 [==============================] - 0s 160us/step - loss: 0.0132 - accuracy: 0.9774\n",
"Epoch 20/50\n",
"931/931 [==============================] - 0s 179us/step - loss: 0.0167 - accuracy: 0.9656\n",
"Epoch 21/50\n",
"931/931 [==============================] - 0s 185us/step - loss: 0.0070 - accuracy: 0.9850\n",
"Epoch 22/50\n",
"931/931 [==============================] - 0s 159us/step - loss: 0.0041 - accuracy: 0.9936\n",
"Epoch 23/50\n",
"931/931 [==============================] - 0s 162us/step - loss: 0.0019 - accuracy: 0.9989\n",
"Epoch 24/50\n",
"931/931 [==============================] - 0s 165us/step - loss: 0.0014 - accuracy: 0.9989\n",
"Epoch 25/50\n",
"931/931 [==============================] - 0s 207us/step - loss: 0.0010 - accuracy: 1.0000\n",
"Epoch 26/50\n",
"931/931 [==============================] - 0s 161us/step - loss: 8.1854e-04 - accuracy: 1.0000\n",
"Epoch 27/50\n",
"931/931 [==============================] - 0s 163us/step - loss: 6.7515e-04 - accuracy: 1.0000\n",
"Epoch 28/50\n",
"931/931 [==============================] - 0s 172us/step - loss: 5.9500e-04 - accuracy: 1.0000\n",
"Epoch 29/50\n",
"931/931 [==============================] - 0s 180us/step - loss: 5.1935e-04 - accuracy: 1.0000\n",
"Epoch 30/50\n",
"931/931 [==============================] - 0s 197us/step - loss: 4.4537e-04 - accuracy: 1.0000\n",
"Epoch 31/50\n",
"931/931 [==============================] - 0s 194us/step - loss: 3.9470e-04 - accuracy: 1.0000\n",
"Epoch 32/50\n",
"931/931 [==============================] - 0s 191us/step - loss: 3.5023e-04 - accuracy: 1.0000\n",
"Epoch 33/50\n",
"931/931 [==============================] - 0s 196us/step - loss: 3.1019e-04 - accuracy: 1.0000\n",
"Epoch 34/50\n",
"931/931 [==============================] - 0s 175us/step - loss: 2.7797e-04 - accuracy: 1.0000\n",
"Epoch 35/50\n",
"931/931 [==============================] - 0s 172us/step - loss: 2.5035e-04 - accuracy: 1.0000\n",
"Epoch 36/50\n",
"931/931 [==============================] - 0s 163us/step - loss: 2.2762e-04 - accuracy: 1.0000\n",
"Epoch 37/50\n",
"931/931 [==============================] - 0s 161us/step - loss: 2.0539e-04 - accuracy: 1.0000\n",
"Epoch 38/50\n",
"931/931 [==============================] - 0s 172us/step - loss: 1.8716e-04 - accuracy: 1.0000\n",
"Epoch 39/50\n",
"931/931 [==============================] - 0s 164us/step - loss: 1.6805e-04 - accuracy: 1.0000\n",
"Epoch 40/50\n",
"931/931 [==============================] - 0s 181us/step - loss: 1.5399e-04 - accuracy: 1.0000\n",
"Epoch 41/50\n",
"931/931 [==============================] - 0s 170us/step - loss: 1.3869e-04 - accuracy: 1.0000\n",
"Epoch 42/50\n",
"931/931 [==============================] - 0s 175us/step - loss: 1.2607e-04 - accuracy: 1.0000\n",
"Epoch 43/50\n",
"931/931 [==============================] - 0s 176us/step - loss: 1.1507e-04 - accuracy: 1.0000\n",
"Epoch 44/50\n",
"931/931 [==============================] - 0s 175us/step - loss: 1.0526e-04 - accuracy: 1.0000\n",
"Epoch 45/50\n",
"931/931 [==============================] - 0s 165us/step - loss: 9.6587e-05 - accuracy: 1.0000\n",
"Epoch 46/50\n",
"931/931 [==============================] - 0s 180us/step - loss: 8.8015e-05 - accuracy: 1.0000\n",
"Epoch 47/50\n",
"931/931 [==============================] - 0s 175us/step - loss: 8.1163e-05 - accuracy: 1.0000\n",
"Epoch 48/50\n",
"931/931 [==============================] - 0s 185us/step - loss: 7.3699e-05 - accuracy: 1.0000\n",
"Epoch 49/50\n",
"931/931 [==============================] - 0s 174us/step - loss: 6.7770e-05 - accuracy: 1.0000\n",
"Epoch 50/50\n",
"931/931 [==============================] - 0s 170us/step - loss: 6.2085e-05 - accuracy: 1.0000\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from keras.models import Sequential\n",
"from keras.layers import Dense\n",
"\n",
"# Define the Keras model\n",
"model = Sequential()\n",
"model.add(Dense(100, input_dim=74, activation='relu'))\n",
"model.add(Dense(50, activation='relu'))\n",
"model.add(Dense(1, activation='sigmoid'))\n",
"\n",
"# Compile the Keras model\n",
"model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
"\n",
"# Fit the Keras model on the dataset\n",
"model.fit(X_train_transformed, y_train_encoded, epochs=50, batch_size=8, class_weight={0:0.1, 1:2.0})"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"399/399 [==============================] - 0s 35us/step\n",
"Keras test accuracy: 0.877\n"
]
}
],
"source": [
"# Run the test data through the preprocessor\n",
"X_test_transformed = best_pipe.named_steps['prep'].transform(X_test)\n",
"# Encode the test labels\n",
"y_test_encoded = le.transform(y_test.values.ravel())\n",
"\n",
"# Check model accuracy on test data\n",
"print('Keras test accuracy: %.3f' % (model.evaluate(X_test_transformed, y_test_encoded)[1]))"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"# Save the keras model architecture and weights to disk\n",
"model.save('HR-Attrition-Keras-v1.h5')"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Keras prediction: 0\n"
]
}
],
"source": [
"import keras\n",
"from keras import backend as kerasbackend\n",
"\n",
"kerasbackend.clear_session()\n",
" \n",
"# Load the keras model architecture and weights from disk\n",
"keras_model = keras.models.load_model('HR-Attrition-Keras-v1.h5')\n",
"keras_model._make_predict_function()\n",
"\n",
"print('Keras prediction: %.0f' % (keras_model.predict(X_test_transformed.iloc[[0]])))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}