{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Complete flow to generate a ML model for the HR Attrition dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Prepare the data:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import a custom transformer for preprocessing data based on feature definitions\n", "from preprocessor import Preprocessor" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
samplevariable_typedata_typefeature_strategyhash_features
name
Age41featureintscalingNaN
AttritionYestargetstrnoneNaN
BusinessTravelTravel_Rarelyfeaturestrone hot encodingNaN
DailyRate1102featureintscalingNaN
DepartmentSalesfeaturestrone hot encodingNaN
\n", "
" ], "text/plain": [ " sample variable_type data_type feature_strategy \\\n", "name \n", "Age 41 feature int scaling \n", "Attrition Yes target str none \n", "BusinessTravel Travel_Rarely feature str one hot encoding \n", "DailyRate 1102 feature int scaling \n", "Department Sales feature str one hot encoding \n", "\n", " hash_features \n", "name \n", "Age NaN \n", "Attrition NaN \n", "BusinessTravel NaN \n", "DailyRate NaN \n", "Department NaN " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "# Import feature definitions and data\n", "sheets = pd.read_excel('Data/HR-Employee-Attrition.xlsx', sheet_name=[\"Feature Definitions\", \"Train-Test\"])\n", "\n", "# Create feature definitions data frame\n", "features = sheets[\"Feature Definitions\"]\n", "features.columns = [c.lower() for c in features.columns]\n", "features.set_index(\"name\", append=False, inplace=True)\n", "features.head()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Attrition
0Yes
1No
2No
3No
4No
\n", "
" ], "text/plain": [ " Attrition\n", "0 Yes\n", "1 No\n", "2 No\n", "3 No\n", "4 No" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Setup the data dataframe\n", "data = sheets[\"Train-Test\"]\n", "\n", "# Get the target features\n", "target = features.loc[features[\"variable_type\"] == \"target\"]\n", "target_name = target.index[0]\n", "\n", "# Get the target data\n", "d_target = data.loc[:,[target_name]]\n", "\n", "d_target.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Get the features to be excluded from the model\n", "exclusions = features['variable_type'].isin([\"excluded\", \"target\", \"identifier\"])\n", "\n", "excluded = features.loc[exclusions]\n", "features = features.loc[~exclusions]\n", "\n", "# Remove excluded features from the data\n", "data = data[features.index.tolist()]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Split the data into training and testing subsets\n", "from sklearn.model_selection import train_test_split\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(data, d_target, test_size=0.30, random_state=42)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeBusinessTravelDailyRateDepartmentDistanceFromHomeEducationEducationFieldEnvironmentSatisfactionGenderHourlyRate...PerformanceRatingRelationshipSatisfactionStockOptionLevelTotalWorkingYearsTrainingTimesLastYearWorkLifeBalanceYearsAtCompanyYearsInCurrentRoleYearsSinceLastPromotionYearsWithCurrManager
89341Travel_Frequently1200Research & Development223Life Sciences4Female75...31212426233
11550Travel_Frequently809Sales123Marketing3Female77...34016332221
52627Travel_Frequently829Sales81Marketing3Male84...3215334211
17535Travel_Frequently138Research & Development23Medical2Female37...34010536212
6345Travel_Rarely193Research & Development64Other4Male52...32017340000
\n", "

5 rows × 31 columns

\n", "
" ], "text/plain": [ " Age BusinessTravel DailyRate Department \\\n", "893 41 Travel_Frequently 1200 Research & Development \n", "115 50 Travel_Frequently 809 Sales \n", "526 27 Travel_Frequently 829 Sales \n", "175 35 Travel_Frequently 138 Research & Development \n", "63 45 Travel_Rarely 193 Research & Development \n", "\n", " DistanceFromHome Education EducationField EnvironmentSatisfaction \\\n", "893 22 3 Life Sciences 4 \n", "115 12 3 Marketing 3 \n", "526 8 1 Marketing 3 \n", "175 2 3 Medical 2 \n", "63 6 4 Other 4 \n", "\n", " Gender HourlyRate ... PerformanceRating RelationshipSatisfaction \\\n", "893 Female 75 ... 3 1 \n", "115 Female 77 ... 3 4 \n", "526 Male 84 ... 3 2 \n", "175 Female 37 ... 3 4 \n", "63 Male 52 ... 3 2 \n", "\n", " StockOptionLevel TotalWorkingYears TrainingTimesLastYear \\\n", "893 2 12 4 \n", "115 0 16 3 \n", "526 1 5 3 \n", "175 0 10 5 \n", "63 0 17 3 \n", "\n", " WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n", "893 2 6 2 \n", "115 3 2 2 \n", "526 3 4 2 \n", "175 3 6 2 \n", "63 4 0 0 \n", "\n", " YearsSinceLastPromotion YearsWithCurrManager \n", "893 3 3 \n", "115 2 1 \n", "526 1 1 \n", "175 1 2 \n", "63 0 0 \n", "\n", "[5 rows x 31 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test.head()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
BusinessTravel_Non-TravelBusinessTravel_Travel_FrequentlyBusinessTravel_Travel_RarelyDepartment_Human ResourcesDepartment_Research & DevelopmentDepartment_SalesEducation_1Education_2Education_3Education_4...TotalWorkingYearsTrainingTimesLastYearYearsAtCompanyYearsInCurrentRoleYearsSinceLastPromotionYearsWithCurrManagerJobRole0JobRole1JobRole2JobRole3
8930100100010...0.0837890.927695-0.158869-0.6137940.232242-0.328203-0.994429-0.451462-0.816773-1.706519
1150100010010...0.5911540.153090-0.807584-0.613794-0.071434-0.8827521.265490-0.034487-0.8167730.899445
5260100011000...-0.8041000.153090-0.483227-0.613794-0.375111-0.8827521.265490-0.034487-0.8167730.899445
1750100100010...-0.1698941.702300-0.158869-0.613794-0.375111-0.605477-0.994429-1.7023871.8083670.899445
630010100001...0.7179950.153090-1.131942-1.175408-0.678787-1.1600260.5121840.7994630.0582740.899445
..................................................................
3810011000100...-1.3114640.153090-0.969763-1.175408-0.678787-1.1600262.0187961.633413-0.8167730.030790
2920010010010...0.2106300.1530900.0033100.790241-0.3751110.7808961.265490-0.034487-0.8167730.899445
10830010011000...-1.1846230.153090-0.807584-0.613794-0.678787-0.605477-0.2411231.216438-0.816773-0.837864
5510010100010...-1.3114640.153090-0.969763-1.175408-0.678787-1.160026-0.994429-0.451462-0.816773-1.706519
1410100010001...-0.6772580.153090-0.483227-0.332987-0.375111-0.605477-0.2411231.216438-0.816773-0.837864
\n", "

399 rows × 74 columns

\n", "
" ], "text/plain": [ " BusinessTravel_Non-Travel BusinessTravel_Travel_Frequently \\\n", "893 0 1 \n", "115 0 1 \n", "526 0 1 \n", "175 0 1 \n", "63 0 0 \n", "... ... ... \n", "381 0 0 \n", "292 0 0 \n", "1083 0 0 \n", "551 0 0 \n", "141 0 1 \n", "\n", " BusinessTravel_Travel_Rarely Department_Human Resources \\\n", "893 0 0 \n", "115 0 0 \n", "526 0 0 \n", "175 0 0 \n", "63 1 0 \n", "... ... ... \n", "381 1 1 \n", "292 1 0 \n", "1083 1 0 \n", "551 1 0 \n", "141 0 0 \n", "\n", " Department_Research & Development Department_Sales Education_1 \\\n", "893 1 0 0 \n", "115 0 1 0 \n", "526 0 1 1 \n", "175 1 0 0 \n", "63 1 0 0 \n", "... ... ... ... \n", "381 0 0 0 \n", "292 0 1 0 \n", "1083 0 1 1 \n", "551 1 0 0 \n", "141 0 1 0 \n", "\n", " Education_2 Education_3 Education_4 ... TotalWorkingYears \\\n", "893 0 1 0 ... 0.083789 \n", "115 0 1 0 ... 0.591154 \n", "526 0 0 0 ... -0.804100 \n", "175 0 1 0 ... -0.169894 \n", "63 0 0 1 ... 0.717995 \n", "... ... ... ... ... ... \n", "381 1 0 0 ... -1.311464 \n", "292 0 1 0 ... 0.210630 \n", "1083 0 0 0 ... -1.184623 \n", "551 0 1 0 ... -1.311464 \n", "141 0 0 1 ... -0.677258 \n", "\n", " TrainingTimesLastYear YearsAtCompany YearsInCurrentRole \\\n", "893 0.927695 -0.158869 -0.613794 \n", "115 0.153090 -0.807584 -0.613794 \n", "526 0.153090 -0.483227 -0.613794 \n", "175 1.702300 -0.158869 -0.613794 \n", "63 0.153090 -1.131942 -1.175408 \n", "... ... ... ... \n", "381 0.153090 -0.969763 -1.175408 \n", "292 0.153090 0.003310 0.790241 \n", "1083 0.153090 -0.807584 -0.613794 \n", "551 0.153090 -0.969763 -1.175408 \n", "141 0.153090 -0.483227 -0.332987 \n", "\n", " YearsSinceLastPromotion YearsWithCurrManager JobRole0 JobRole1 \\\n", "893 0.232242 -0.328203 -0.994429 -0.451462 \n", "115 -0.071434 -0.882752 1.265490 -0.034487 \n", "526 -0.375111 -0.882752 1.265490 -0.034487 \n", "175 -0.375111 -0.605477 -0.994429 -1.702387 \n", "63 -0.678787 -1.160026 0.512184 0.799463 \n", "... ... ... ... ... 
\n", "381 -0.678787 -1.160026 2.018796 1.633413 \n", "292 -0.375111 0.780896 1.265490 -0.034487 \n", "1083 -0.678787 -0.605477 -0.241123 1.216438 \n", "551 -0.678787 -1.160026 -0.994429 -0.451462 \n", "141 -0.375111 -0.605477 -0.241123 1.216438 \n", "\n", " JobRole2 JobRole3 \n", "893 -0.816773 -1.706519 \n", "115 -0.816773 0.899445 \n", "526 -0.816773 0.899445 \n", "175 1.808367 0.899445 \n", "63 0.058274 0.899445 \n", "... ... ... \n", "381 -0.816773 0.030790 \n", "292 -0.816773 0.899445 \n", "1083 -0.816773 -0.837864 \n", "551 -0.816773 -1.706519 \n", "141 -0.816773 -0.837864 \n", "\n", "[399 rows x 74 columns]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test the preprocessor\n", "prep = Preprocessor(features, return_type='df').fit(X_train)\n", "prep.transform(X_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Set up machine learning pipelines, fit and score the models:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Logistic Regression pipeline test accuracy: 0.905\n", "Random Forest pipeline test accuracy: 0.862\n", "Classifier with best accuracy: Logistic Regression\n", "Saved Logistic Regression pipeline to file\n" ] } ], "source": [ "# Set up a pipeline to run the ML flow\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.ensemble import RandomForestClassifier\n", "import pickle\n", "\n", "# Construct the pipelines\n", "pipe_lr = Pipeline([('prep', Preprocessor(features, return_type='df')), ('clf', LogisticRegression(solver='lbfgs', random_state=42))])\n", "pipe_rf = Pipeline([('prep', Preprocessor(features, return_type='df')), ('clf', RandomForestClassifier(n_estimators=10, random_state=42))])\n", "\n", "# List of pipelines for ease of iteration\n", "pipelines = [pipe_lr, pipe_rf]\n", "\n", "# Dictionary of pipelines and 
classifier types for ease of reference\n", "pipe_dict = {0: 'Logistic Regression', 1: 'Random Forest'}\n", "\n", "# Fit the pipelines\n", "for pipe in pipelines:\n", " pipe.fit(X_train, y_train.values.ravel())\n", "\n", "# Compare accuracies\n", "for idx, val in enumerate(pipelines):\n", " print('%s pipeline test accuracy: %.3f' % (pipe_dict[idx], val.score(X_test, y_test)))\n", "\n", "# Identify the most accurate model on test data\n", "best_acc = 0.0\n", "best_clf = 0\n", "best_pipe = ''\n", "for idx, val in enumerate(pipelines):\n", " if val.score(X_test, y_test) > best_acc:\n", " best_acc = val.score(X_test, y_test.values.ravel())\n", " best_pipe = val\n", " best_clf = idx\n", "print('Classifier with best accuracy: %s' % pipe_dict[best_clf])\n", "\n", "# Save pipeline to file\n", "with open('HR-Attrition-v1.pkl', 'wb') as file:\n", " pickle.dump(best_pipe, file)\n", " print('Saved %s pipeline to file' % pipe_dict[best_clf])\n", "\n", "# Also save the preprocessor and model as separate files\n", "with open('HR-Attrition-v1-prep.pkl', 'wb') as file:\n", " pickle.dump(best_pipe.named_steps['prep'], file)\n", "with open('HR-Attrition-v1-clf.pkl', 'wb') as file:\n", " pickle.dump(best_pipe.named_steps['clf'], file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Validate the saved model:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [ { "data": { "text/plain": [ "Pipeline(memory=None,\n", " steps=[('prep',\n", " ),\n", " ('clf',\n", " LogisticRegression(C=1.0, class_weight=None, dual=False,\n", " fit_intercept=True, intercept_scaling=1,\n", " l1_ratio=None, max_iter=100,\n", " multi_class='warn', n_jobs=None,\n", " penalty='l2', random_state=42,\n", " solver='lbfgs', tol=0.0001, verbose=0,\n", " warm_start=False))],\n", " verbose=False)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the saved pipeline from disk\n", "with 
open('HR-Attrition-v1.pkl', 'rb') as file:\n", " model = pickle.load(file)\n", "\n", "model" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeAttritionBusinessTravelDailyRateDepartmentDistanceFromHomeEducationEducationFieldEmployeeCountEmployeeNumber...RelationshipSatisfactionStandardHoursStockOptionLevelTotalWorkingYearsTrainingTimesLastYearWorkLifeBalanceYearsAtCompanyYearsInCurrentRoleYearsSinceLastPromotionYearsWithCurrManager
037YesTravel_Rarely1373Research & Development22Other14...28007330000
127NoTravel_Rarely591Research & Development21Medical17...48016332222
232NoTravel_Frequently1005Research & Development22Life Sciences18...38008227736
353NoTravel_Rarely1282Research & Development53Other132...48012632141348
443NoTravel_Rarely1273Research & Development22Medical146...48026325314
..................................................................
13534NoTravel_Rarely704Sales283Marketing12035...48028238717
13636NoTravel_Rarely1120Sales114Marketing12045...18018226300
13729NoTravel_Rarely468Research & Development284Medical12054...28005315404
13839NoTravel_Rarely722Sales241Marketing12056...1801212220996
13936NoTravel_Frequently884Research & Development232Medical12061...380117335203
\n", "

140 rows × 35 columns

\n", "
" ], "text/plain": [ " Age Attrition BusinessTravel DailyRate Department \\\n", "0 37 Yes Travel_Rarely 1373 Research & Development \n", "1 27 No Travel_Rarely 591 Research & Development \n", "2 32 No Travel_Frequently 1005 Research & Development \n", "3 53 No Travel_Rarely 1282 Research & Development \n", "4 43 No Travel_Rarely 1273 Research & Development \n", ".. ... ... ... ... ... \n", "135 34 No Travel_Rarely 704 Sales \n", "136 36 No Travel_Rarely 1120 Sales \n", "137 29 No Travel_Rarely 468 Research & Development \n", "138 39 No Travel_Rarely 722 Sales \n", "139 36 No Travel_Frequently 884 Research & Development \n", "\n", " DistanceFromHome Education EducationField EmployeeCount \\\n", "0 2 2 Other 1 \n", "1 2 1 Medical 1 \n", "2 2 2 Life Sciences 1 \n", "3 5 3 Other 1 \n", "4 2 2 Medical 1 \n", ".. ... ... ... ... \n", "135 28 3 Marketing 1 \n", "136 11 4 Marketing 1 \n", "137 28 4 Medical 1 \n", "138 24 1 Marketing 1 \n", "139 23 2 Medical 1 \n", "\n", " EmployeeNumber ... RelationshipSatisfaction StandardHours \\\n", "0 4 ... 2 80 \n", "1 7 ... 4 80 \n", "2 8 ... 3 80 \n", "3 32 ... 4 80 \n", "4 46 ... 4 80 \n", ".. ... ... ... ... \n", "135 2035 ... 4 80 \n", "136 2045 ... 1 80 \n", "137 2054 ... 2 80 \n", "138 2056 ... 1 80 \n", "139 2061 ... 3 80 \n", "\n", " StockOptionLevel TotalWorkingYears TrainingTimesLastYear \\\n", "0 0 7 3 \n", "1 1 6 3 \n", "2 0 8 2 \n", "3 1 26 3 \n", "4 2 6 3 \n", ".. ... ... ... \n", "135 2 8 2 \n", "136 1 8 2 \n", "137 0 5 3 \n", "138 1 21 2 \n", "139 1 17 3 \n", "\n", " WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n", "0 3 0 0 \n", "1 3 2 2 \n", "2 2 7 7 \n", "3 2 14 13 \n", "4 2 5 3 \n", ".. ... ... ... \n", "135 3 8 7 \n", "136 2 6 3 \n", "137 1 5 4 \n", "138 2 20 9 \n", "139 3 5 2 \n", "\n", " YearsSinceLastPromotion YearsWithCurrManager \n", "0 0 0 \n", "1 2 2 \n", "2 3 6 \n", "3 4 8 \n", "4 1 4 \n", ".. ... ... 
\n", "135 1 7 \n", "136 0 0 \n", "137 0 4 \n", "138 9 6 \n", "139 0 3 \n", "\n", "[140 rows x 35 columns]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load additional data to test the saved model\n", "validation = pd.read_excel('Data/HR-Employee-Attrition.xlsx', sheet_name=\"Validate\")\n", "validation" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Attrition
0Yes
1No
2No
3No
4No
\n", "
" ], "text/plain": [ " Attrition\n", "0 Yes\n", "1 No\n", "2 No\n", "3 No\n", "4 No" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the targets\n", "v_target = validation.loc[:,[target_name]]\n", "v_target.head()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeBusinessTravelDailyRateDepartmentDistanceFromHomeEducationEducationFieldEnvironmentSatisfactionGenderHourlyRate...PerformanceRatingRelationshipSatisfactionStockOptionLevelTotalWorkingYearsTrainingTimesLastYearWorkLifeBalanceYearsAtCompanyYearsInCurrentRoleYearsSinceLastPromotionYearsWithCurrManager
037Travel_Rarely1373Research & Development22Other4Male92...3207330000
127Travel_Rarely591Research & Development21Medical1Male40...3416332222
232Travel_Frequently1005Research & Development22Life Sciences4Male79...3308227736
353Travel_Rarely1282Research & Development53Other3Female58...3412632141348
443Travel_Rarely1273Research & Development22Medical4Female72...3426325314
..................................................................
13534Travel_Rarely704Sales283Marketing4Female95...4428238717
13636Travel_Rarely1120Sales114Marketing2Female100...3118226300
13729Travel_Rarely468Research & Development284Medical4Female73...3205315404
13839Travel_Rarely722Sales241Marketing2Female60...311212220996
13936Travel_Frequently884Research & Development232Medical3Male41...33117335203
\n", "

140 rows × 31 columns

\n", "
" ], "text/plain": [ " Age BusinessTravel DailyRate Department \\\n", "0 37 Travel_Rarely 1373 Research & Development \n", "1 27 Travel_Rarely 591 Research & Development \n", "2 32 Travel_Frequently 1005 Research & Development \n", "3 53 Travel_Rarely 1282 Research & Development \n", "4 43 Travel_Rarely 1273 Research & Development \n", ".. ... ... ... ... \n", "135 34 Travel_Rarely 704 Sales \n", "136 36 Travel_Rarely 1120 Sales \n", "137 29 Travel_Rarely 468 Research & Development \n", "138 39 Travel_Rarely 722 Sales \n", "139 36 Travel_Frequently 884 Research & Development \n", "\n", " DistanceFromHome Education EducationField EnvironmentSatisfaction \\\n", "0 2 2 Other 4 \n", "1 2 1 Medical 1 \n", "2 2 2 Life Sciences 4 \n", "3 5 3 Other 3 \n", "4 2 2 Medical 4 \n", ".. ... ... ... ... \n", "135 28 3 Marketing 4 \n", "136 11 4 Marketing 2 \n", "137 28 4 Medical 4 \n", "138 24 1 Marketing 2 \n", "139 23 2 Medical 3 \n", "\n", " Gender HourlyRate ... PerformanceRating RelationshipSatisfaction \\\n", "0 Male 92 ... 3 2 \n", "1 Male 40 ... 3 4 \n", "2 Male 79 ... 3 3 \n", "3 Female 58 ... 3 4 \n", "4 Female 72 ... 3 4 \n", ".. ... ... ... ... ... \n", "135 Female 95 ... 4 4 \n", "136 Female 100 ... 3 1 \n", "137 Female 73 ... 3 2 \n", "138 Female 60 ... 3 1 \n", "139 Male 41 ... 3 3 \n", "\n", " StockOptionLevel TotalWorkingYears TrainingTimesLastYear \\\n", "0 0 7 3 \n", "1 1 6 3 \n", "2 0 8 2 \n", "3 1 26 3 \n", "4 2 6 3 \n", ".. ... ... ... \n", "135 2 8 2 \n", "136 1 8 2 \n", "137 0 5 3 \n", "138 1 21 2 \n", "139 1 17 3 \n", "\n", " WorkLifeBalance YearsAtCompany YearsInCurrentRole \\\n", "0 3 0 0 \n", "1 3 2 2 \n", "2 2 7 7 \n", "3 2 14 13 \n", "4 2 5 3 \n", ".. ... ... ... \n", "135 3 8 7 \n", "136 2 6 3 \n", "137 1 5 4 \n", "138 2 20 9 \n", "139 3 5 2 \n", "\n", " YearsSinceLastPromotion YearsWithCurrManager \n", "0 0 0 \n", "1 2 2 \n", "2 3 6 \n", "3 4 8 \n", "4 1 4 \n", ".. ... ... 
\n", "135 1 7 \n", "136 0 0 \n", "137 0 4 \n", "138 9 6 \n", "139 0 3 \n", "\n", "[140 rows x 31 columns]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Remove excluded features from the validation dataset\n", "validation = validation[features.index.tolist()]\n", "validation" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8785714285714286" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get a score for the validation dataset from the saved pipeline\n", "model.score(validation, v_target.values.ravel())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Train a Keras model:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
BusinessTravel_Non-TravelBusinessTravel_Travel_FrequentlyBusinessTravel_Travel_RarelyDepartment_Human ResourcesDepartment_Research & DevelopmentDepartment_SalesEducation_1Education_2Education_3Education_4...TotalWorkingYearsTrainingTimesLastYearYearsAtCompanyYearsInCurrentRoleYearsSinceLastPromotionYearsWithCurrManagerJobRole0JobRole1JobRole2JobRole3
12980010100010...0.591154-0.6215141.4629191.9134681.1432721.0581700.5121840.7994630.0582740.899445
6200010100001...0.210630-0.621514-1.131942-1.175408-0.678787-1.160026-0.994429-0.451462-0.816773-1.706519
11930010100001...0.464312-0.621514-0.969763-1.175408-0.678787-1.160026-0.994429-0.451462-0.816773-1.706519
1390010100010...-0.2967350.153090-0.969763-1.175408-0.678787-1.160026-0.994429-0.451462-0.816773-1.706519
11650010101000...-0.2967350.153090-0.483227-0.332987-0.071434-0.6054770.5121840.7994630.0582740.899445
\n", "

5 rows × 74 columns

\n", "
" ], "text/plain": [ " BusinessTravel_Non-Travel BusinessTravel_Travel_Frequently \\\n", "1298 0 0 \n", "620 0 0 \n", "1193 0 0 \n", "139 0 0 \n", "1165 0 0 \n", "\n", " BusinessTravel_Travel_Rarely Department_Human Resources \\\n", "1298 1 0 \n", "620 1 0 \n", "1193 1 0 \n", "139 1 0 \n", "1165 1 0 \n", "\n", " Department_Research & Development Department_Sales Education_1 \\\n", "1298 1 0 0 \n", "620 1 0 0 \n", "1193 1 0 0 \n", "139 1 0 0 \n", "1165 1 0 1 \n", "\n", " Education_2 Education_3 Education_4 ... TotalWorkingYears \\\n", "1298 0 1 0 ... 0.591154 \n", "620 0 0 1 ... 0.210630 \n", "1193 0 0 1 ... 0.464312 \n", "139 0 1 0 ... -0.296735 \n", "1165 0 0 0 ... -0.296735 \n", "\n", " TrainingTimesLastYear YearsAtCompany YearsInCurrentRole \\\n", "1298 -0.621514 1.462919 1.913468 \n", "620 -0.621514 -1.131942 -1.175408 \n", "1193 -0.621514 -0.969763 -1.175408 \n", "139 0.153090 -0.969763 -1.175408 \n", "1165 0.153090 -0.483227 -0.332987 \n", "\n", " YearsSinceLastPromotion YearsWithCurrManager JobRole0 JobRole1 \\\n", "1298 1.143272 1.058170 0.512184 0.799463 \n", "620 -0.678787 -1.160026 -0.994429 -0.451462 \n", "1193 -0.678787 -1.160026 -0.994429 -0.451462 \n", "139 -0.678787 -1.160026 -0.994429 -0.451462 \n", "1165 -0.071434 -0.605477 0.512184 0.799463 \n", "\n", " JobRole2 JobRole3 \n", "1298 0.058274 0.899445 \n", "620 -0.816773 -1.706519 \n", "1193 -0.816773 -1.706519 \n", "139 -0.816773 -1.706519 \n", "1165 0.058274 0.899445 \n", "\n", "[5 rows x 74 columns]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Run the training data through the preprocessor\n", "X_train_transformed = best_pipe.named_steps['prep'].transform(X_train)\n", "X_train_transformed.head(5)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0])" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "from sklearn.preprocessing import LabelEncoder\n", "\n", "# Encode target values. LabelEncoder sorts the labels, so 'No' -> 0 and 'Yes' -> 1;\n", "# the class_weight used when fitting the Keras model relies on this mapping.\n", "le = LabelEncoder()\n", "y_train_encoded = le.fit_transform(y_train.values.ravel())\n", "assert list(le.classes_) == ['No', 'Yes'], 'Unexpected target classes: %s' % list(le.classes_)\n", "y_train_encoded[:10]" ] }, { "cell_type": "code", "execution_count": 71, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/50\n", "931/931 [==============================] - 0s 498us/step - loss: 0.2025 - accuracy: 0.2266\n", "Epoch 2/50\n", "931/931 [==============================] - 0s 165us/step - loss: 0.1567 - accuracy: 0.3373\n", "Epoch 3/50\n", "931/931 [==============================] - 0s 178us/step - loss: 0.1286 - accuracy: 0.5800\n", "Epoch 4/50\n", "931/931 [==============================] - 0s 165us/step - loss: 0.1126 - accuracy: 0.6273\n", "Epoch 5/50\n", "931/931 [==============================] - 0s 167us/step - loss: 0.0951 - accuracy: 0.7132\n", "Epoch 6/50\n", "931/931 [==============================] - 0s 164us/step - loss: 0.0796 - accuracy: 0.7444\n", "Epoch 7/50\n", "931/931 [==============================] - 0s 155us/step - loss: 0.0626 - accuracy: 0.8163\n", "Epoch 8/50\n", "931/931 [==============================] - 0s 187us/step - loss: 0.0501 - accuracy: 0.8528\n", "Epoch 9/50\n", "931/931 [==============================] - 0s 202us/step - loss: 0.0432 - accuracy: 0.8937\n", "Epoch 10/50\n", "931/931 [==============================] - 0s 216us/step - loss: 0.0341 - accuracy: 0.9151\n", "Epoch 11/50\n", "931/931 [==============================] - 0s 209us/step - loss: 0.0302 - accuracy: 0.9323\n", "Epoch 12/50\n", "931/931 [==============================] - 0s 201us/step - loss: 0.0207 - accuracy: 0.9538\n", "Epoch 13/50\n", "931/931 [==============================] - 0s 194us/step - loss: 0.0166 - accuracy: 0.9646\n", "Epoch 14/50\n", "931/931 [==============================] - 0s 215us/step - loss: 0.0134 - accuracy: 0.9678\n", "Epoch 15/50\n", "931/931 
[==============================] - 0s 178us/step - loss: 0.0094 - accuracy: 0.9817\n", "Epoch 16/50\n", "931/931 [==============================] - 0s 159us/step - loss: 0.0072 - accuracy: 0.9903\n", "Epoch 17/50\n", "931/931 [==============================] - 0s 162us/step - loss: 0.0052 - accuracy: 0.9936\n", "Epoch 18/50\n", "931/931 [==============================] - 0s 162us/step - loss: 0.0041 - accuracy: 0.9946\n", "Epoch 19/50\n", "931/931 [==============================] - 0s 160us/step - loss: 0.0132 - accuracy: 0.9774\n", "Epoch 20/50\n", "931/931 [==============================] - 0s 179us/step - loss: 0.0167 - accuracy: 0.9656\n", "Epoch 21/50\n", "931/931 [==============================] - 0s 185us/step - loss: 0.0070 - accuracy: 0.9850\n", "Epoch 22/50\n", "931/931 [==============================] - 0s 159us/step - loss: 0.0041 - accuracy: 0.9936\n", "Epoch 23/50\n", "931/931 [==============================] - 0s 162us/step - loss: 0.0019 - accuracy: 0.9989\n", "Epoch 24/50\n", "931/931 [==============================] - 0s 165us/step - loss: 0.0014 - accuracy: 0.9989\n", "Epoch 25/50\n", "931/931 [==============================] - 0s 207us/step - loss: 0.0010 - accuracy: 1.0000\n", "Epoch 26/50\n", "931/931 [==============================] - 0s 161us/step - loss: 8.1854e-04 - accuracy: 1.0000\n", "Epoch 27/50\n", "931/931 [==============================] - 0s 163us/step - loss: 6.7515e-04 - accuracy: 1.0000\n", "Epoch 28/50\n", "931/931 [==============================] - 0s 172us/step - loss: 5.9500e-04 - accuracy: 1.0000\n", "Epoch 29/50\n", "931/931 [==============================] - 0s 180us/step - loss: 5.1935e-04 - accuracy: 1.0000\n", "Epoch 30/50\n", "931/931 [==============================] - 0s 197us/step - loss: 4.4537e-04 - accuracy: 1.0000\n", "Epoch 31/50\n", "931/931 [==============================] - 0s 194us/step - loss: 3.9470e-04 - accuracy: 1.0000\n", "Epoch 32/50\n", "931/931 [==============================] - 0s 191us/step - 
loss: 3.5023e-04 - accuracy: 1.0000\n", "Epoch 33/50\n", "931/931 [==============================] - 0s 196us/step - loss: 3.1019e-04 - accuracy: 1.0000\n", "Epoch 34/50\n", "931/931 [==============================] - 0s 175us/step - loss: 2.7797e-04 - accuracy: 1.0000\n", "Epoch 35/50\n", "931/931 [==============================] - 0s 172us/step - loss: 2.5035e-04 - accuracy: 1.0000\n", "Epoch 36/50\n", "931/931 [==============================] - 0s 163us/step - loss: 2.2762e-04 - accuracy: 1.0000\n", "Epoch 37/50\n", "931/931 [==============================] - 0s 161us/step - loss: 2.0539e-04 - accuracy: 1.0000\n", "Epoch 38/50\n", "931/931 [==============================] - 0s 172us/step - loss: 1.8716e-04 - accuracy: 1.0000\n", "Epoch 39/50\n", "931/931 [==============================] - 0s 164us/step - loss: 1.6805e-04 - accuracy: 1.0000\n", "Epoch 40/50\n", "931/931 [==============================] - 0s 181us/step - loss: 1.5399e-04 - accuracy: 1.0000\n", "Epoch 41/50\n", "931/931 [==============================] - 0s 170us/step - loss: 1.3869e-04 - accuracy: 1.0000\n", "Epoch 42/50\n", "931/931 [==============================] - 0s 175us/step - loss: 1.2607e-04 - accuracy: 1.0000\n", "Epoch 43/50\n", "931/931 [==============================] - 0s 176us/step - loss: 1.1507e-04 - accuracy: 1.0000\n", "Epoch 44/50\n", "931/931 [==============================] - 0s 175us/step - loss: 1.0526e-04 - accuracy: 1.0000\n", "Epoch 45/50\n", "931/931 [==============================] - 0s 165us/step - loss: 9.6587e-05 - accuracy: 1.0000\n", "Epoch 46/50\n", "931/931 [==============================] - 0s 180us/step - loss: 8.8015e-05 - accuracy: 1.0000\n", "Epoch 47/50\n", "931/931 [==============================] - 0s 175us/step - loss: 8.1163e-05 - accuracy: 1.0000\n", "Epoch 48/50\n", "931/931 [==============================] - 0s 185us/step - loss: 7.3699e-05 - accuracy: 1.0000\n", "Epoch 49/50\n", "931/931 [==============================] - 0s 174us/step - loss: 
6.7770e-05 - accuracy: 1.0000\n", "Epoch 50/50\n", "931/931 [==============================] - 0s 170us/step - loss: 6.2085e-05 - accuracy: 1.0000\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "from keras.models import Sequential\n", "from keras.layers import Dense\n", "\n", "# Class weights for fit(): class 1 (presumably the minority 'Yes' label) is up-weighted\n", "# relative to class 0 to counter the imbalance in the target.\n", "CLASS_WEIGHT = {0: 0.1, 1: 2.0}\n", "\n", "# Define the Keras model. The input width is taken from the preprocessed data rather\n", "# than hard-coded, so it stays correct if the feature definitions change.\n", "model = Sequential()\n", "model.add(Dense(100, input_dim=X_train_transformed.shape[1], activation='relu'))\n", "model.add(Dense(50, activation='relu'))\n", "model.add(Dense(1, activation='sigmoid'))\n", "\n", "# Compile the Keras model\n", "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n", "\n", "# Fit the Keras model on the dataset.\n", "# NOTE: no validation data is held out here; the logged training accuracy reaching 1.0\n", "# suggests overfitting, so judge the model on the test set in the next cell.\n", "model.fit(X_train_transformed, y_train_encoded, epochs=50, batch_size=8, class_weight=CLASS_WEIGHT)" ] }, 
{ "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "399/399 [==============================] - 0s 35us/step\n", "Keras test accuracy: 0.877\n" ] } ], "source": [ "# Run the test data through the preprocessor\n", "X_test_transformed = best_pipe.named_steps['prep'].transform(X_test)\n", "# Encode the test labels\n", "y_test_encoded = le.transform(y_test.values.ravel())\n", "\n", "# Check model accuracy on test data\n", "# NOTE: accuracy alone can be misleading when the target classes are imbalanced.\n", "print('Keras test accuracy: %.3f' % (model.evaluate(X_test_transformed, y_test_encoded)[1]))" ] }, 
{ "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "# Save the keras model architecture and weights to disk\n", "model.save('HR-Attrition-Keras-v1.h5')" ] }, 
{ "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Keras prediction: 0\n" ] } ], "source": [ "import keras\n", "from keras import backend as kerasbackend\n", "\n", "kerasbackend.clear_session()\n", "\n", "# Load the keras model architecture and weights from disk\n", "keras_model = keras.models.load_model('HR-Attrition-Keras-v1.h5')\n", "# _make_predict_function() is a private Keras API that newer Keras releases no longer\n", "# provide; call it only when present so this cell also runs on those versions.\n", "if hasattr(keras_model, '_make_predict_function'):\n", "    keras_model._make_predict_function()\n", "\n", "# The single-unit sigmoid output layer gives predict() an array of shape (n_samples, 1).\n", "# Take the one probability and threshold it at 0.5 to get the predicted class.\n", "probability = keras_model.predict(X_test_transformed.iloc[[0]])[0, 0]\n", "print('Keras prediction: %d' % int(probability > 0.5))" ] } ], 
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }