diff --git a/Analysis.ipynb b/Analysis.ipynb
index 17c4369b..4b2c9f80 100644
--- a/Analysis.ipynb
+++ b/Analysis.ipynb
@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"id": "557ed2b5",
"metadata": {},
"outputs": [
@@ -26,247 +26,32 @@
"output_type": "display_data"
},
{
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " age | \n",
- " workclass | \n",
- " education.num | \n",
- " marital.status | \n",
- " occupation | \n",
- " relationship | \n",
- " race | \n",
- " sex | \n",
- " capital.gain | \n",
- " capital.loss | \n",
- " hours.per.week | \n",
- " native.country | \n",
- " income | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " | 1 | \n",
- " 82 | \n",
- " 2 | \n",
- " 9 | \n",
- " 6 | \n",
- " 3 | \n",
- " 1 | \n",
- " 4 | \n",
- " 0 | \n",
- " 0 | \n",
- " 4356 | \n",
- " 18 | \n",
- " 38 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " | 3 | \n",
- " 54 | \n",
- " 2 | \n",
- " 4 | \n",
- " 0 | \n",
- " 6 | \n",
- " 4 | \n",
- " 4 | \n",
- " 0 | \n",
- " 0 | \n",
- " 3900 | \n",
- " 40 | \n",
- " 38 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " | 4 | \n",
- " 41 | \n",
- " 2 | \n",
- " 10 | \n",
- " 5 | \n",
- " 9 | \n",
- " 3 | \n",
- " 4 | \n",
- " 0 | \n",
- " 0 | \n",
- " 3900 | \n",
- " 40 | \n",
- " 38 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " | 5 | \n",
- " 34 | \n",
- " 2 | \n",
- " 9 | \n",
- " 0 | \n",
- " 7 | \n",
- " 4 | \n",
- " 4 | \n",
- " 0 | \n",
- " 0 | \n",
- " 3770 | \n",
- " 45 | \n",
- " 38 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " | 6 | \n",
- " 38 | \n",
- " 2 | \n",
- " 6 | \n",
- " 5 | \n",
- " 0 | \n",
- " 4 | \n",
- " 4 | \n",
- " 1 | \n",
- " 0 | \n",
- " 3770 | \n",
- " 40 | \n",
- " 38 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " | 7 | \n",
- " 74 | \n",
- " 5 | \n",
- " 16 | \n",
- " 4 | \n",
- " 9 | \n",
- " 2 | \n",
- " 4 | \n",
- " 0 | \n",
- " 0 | \n",
- " 3683 | \n",
- " 20 | \n",
- " 38 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " | 8 | \n",
- " 68 | \n",
- " 0 | \n",
- " 9 | \n",
- " 0 | \n",
- " 9 | \n",
- " 1 | \n",
- " 4 | \n",
- " 0 | \n",
- " 0 | \n",
- " 3683 | \n",
- " 40 | \n",
- " 38 | \n",
- " 0 | \n",
- "
\n",
- " \n",
- " | 10 | \n",
- " 45 | \n",
- " 2 | \n",
- " 16 | \n",
- " 0 | \n",
- " 9 | \n",
- " 4 | \n",
- " 2 | \n",
- " 0 | \n",
- " 0 | \n",
- " 3004 | \n",
- " 35 | \n",
- " 38 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " | 11 | \n",
- " 38 | \n",
- " 4 | \n",
- " 15 | \n",
- " 4 | \n",
- " 9 | \n",
- " 1 | \n",
- " 4 | \n",
- " 1 | \n",
- " 0 | \n",
- " 2824 | \n",
- " 45 | \n",
- " 38 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " | 12 | \n",
- " 52 | \n",
- " 2 | \n",
- " 13 | \n",
- " 6 | \n",
- " 7 | \n",
- " 1 | \n",
- " 4 | \n",
- " 0 | \n",
- " 0 | \n",
- " 2824 | \n",
- " 20 | \n",
- " 38 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " age workclass education.num marital.status occupation relationship \\\n",
- "1 82 2 9 6 3 1 \n",
- "3 54 2 4 0 6 4 \n",
- "4 41 2 10 5 9 3 \n",
- "5 34 2 9 0 7 4 \n",
- "6 38 2 6 5 0 4 \n",
- "7 74 5 16 4 9 2 \n",
- "8 68 0 9 0 9 1 \n",
- "10 45 2 16 0 9 4 \n",
- "11 38 4 15 4 9 1 \n",
- "12 52 2 13 6 7 1 \n",
- "\n",
- " race sex capital.gain capital.loss hours.per.week native.country \\\n",
- "1 4 0 0 4356 18 38 \n",
- "3 4 0 0 3900 40 38 \n",
- "4 4 0 0 3900 40 38 \n",
- "5 4 0 0 3770 45 38 \n",
- "6 4 1 0 3770 40 38 \n",
- "7 4 0 0 3683 20 38 \n",
- "8 4 0 0 3683 40 38 \n",
- "10 2 0 0 3004 35 38 \n",
- "11 4 1 0 2824 45 38 \n",
- "12 4 0 0 2824 20 38 \n",
- "\n",
- " income \n",
- "1 0 \n",
- "3 0 \n",
- "4 0 \n",
- "5 0 \n",
- "6 0 \n",
- "7 1 \n",
- "8 0 \n",
- "10 1 \n",
- "11 1 \n",
- "12 1 "
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "RangeIndex: 32561 entries, 0 to 32560\n",
+ "Data columns (total 15 columns):\n",
+ " # Column Non-Null Count Dtype \n",
+ "--- ------ -------------- ----- \n",
+ " 0 age 32561 non-null int64 \n",
+ " 1 workclass 32561 non-null object\n",
+ " 2 fnlwgt 32561 non-null int64 \n",
+ " 3 education 32561 non-null object\n",
+ " 4 education.num 32561 non-null int64 \n",
+ " 5 marital.status 32561 non-null object\n",
+ " 6 occupation 32561 non-null object\n",
+ " 7 relationship 32561 non-null object\n",
+ " 8 race 32561 non-null object\n",
+ " 9 sex 32561 non-null object\n",
+ " 10 capital.gain 32561 non-null int64 \n",
+ " 11 capital.loss 32561 non-null int64 \n",
+ " 12 hours.per.week 32561 non-null int64 \n",
+ " 13 native.country 32561 non-null object\n",
+ " 14 income 32561 non-null object\n",
+ "dtypes: int64(6), object(9)\n",
+ "memory usage: 3.7+ MB\n"
+ ]
}
],
"source": [
@@ -306,6 +91,7 @@
"plt.show()\n",
"\n",
"#df_encoded.head(10)\n",
+ "df.info()\n",
"\n"
]
},
diff --git a/Decision_tree.ipynb b/Decision_tree.ipynb
index d61600ef..32ff772d 100644
--- a/Decision_tree.ipynb
+++ b/Decision_tree.ipynb
@@ -8,7 +8,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 24,
"id": "0952f099",
"metadata": {},
"outputs": [
@@ -45,7 +45,26 @@
" accuracy 0.81 6033\n",
" macro avg 0.74 0.75 0.74 6033\n",
"weighted avg 0.81 0.81 0.81 6033\n",
- "\n"
+ "\n",
+ "\n",
+ "Index: 18096 entries, 13586 to 5836\n",
+ "Data columns (total 12 columns):\n",
+ " # Column Non-Null Count Dtype\n",
+ "--- ------ -------------- -----\n",
+ " 0 age 18096 non-null int64\n",
+ " 1 workclass 18096 non-null int64\n",
+ " 2 education.num 18096 non-null int64\n",
+ " 3 marital.status 18096 non-null int64\n",
+ " 4 occupation 18096 non-null int64\n",
+ " 5 relationship 18096 non-null int64\n",
+ " 6 race 18096 non-null int64\n",
+ " 7 sex 18096 non-null int64\n",
+ " 8 capital.gain 18096 non-null int64\n",
+ " 9 capital.loss 18096 non-null int64\n",
+ " 10 hours.per.week 18096 non-null int64\n",
+ " 11 native.country 18096 non-null int64\n",
+ "dtypes: int64(12)\n",
+ "memory usage: 1.8 MB\n"
]
}
],
@@ -53,12 +72,14 @@
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
- "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.model_selection import train_test_split, RandomizedSearchCV\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.tree import DecisionTreeClassifier, plot_tree\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error , r2_score\n",
"from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n",
+ "from scipy.stats import randint\n",
+ "\n",
"\n",
"# Load dataset\n",
"df = pd.read_csv('./Datasets/adult.csv', comment = '#')\n",
@@ -123,16 +144,63 @@
"plt.show() \n",
"\n",
"print(\"Classification Report:\")\n",
- "print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))"
+ "print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))\n",
+ "\n",
+ "X_train.info()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 26,
"id": "e567e4e9",
"metadata": {},
- "outputs": [],
- "source": []
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Classification Report:\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " Poor 0.89 0.92 0.90 4524\n",
+ " Rich 0.73 0.65 0.68 1509\n",
+ "\n",
+ " accuracy 0.85 6033\n",
+ " macro avg 0.81 0.78 0.79 6033\n",
+ "weighted avg 0.85 0.85 0.85 6033\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Hyperparameters search space\n",
+ "param_dist = {\n",
+ " 'full_dt_classifier__max_depth': randint(3, 20),\n",
+ " 'full_dt_classifier__min_samples_split': randint(2, 10),\n",
+ " 'full_dt_classifier__min_samples_leaf': randint(1, 10),\n",
+ "}\n",
+ "\n",
+ "# Ranodmized search for hyperparameter tuning\n",
+ "random_search = RandomizedSearchCV(\n",
+ " estimator=model,\n",
+ " param_distributions=param_dist,\n",
+ " n_iter = 50,\n",
+ " cv = 10,\n",
+ " scoring = 'r2',\n",
+ " n_jobs = -1,\n",
+ " random_state = 42\n",
+ ")\n",
+ "\n",
+ "# Fit search\n",
+ "random_search.fit(X_train, y_train)\n",
+ "\n",
+ "# Best model training\n",
+ "best_model = random_search.best_estimator_\n",
+ "y_pred_best = best_model.predict(X_val)\n",
+ "\n",
+ "print(\"Classification Report:\")\n",
+ "print(classification_report(y_val, y_pred_best, target_names=[\"Poor\", \"Rich\"]))"
+ ]
}
],
"metadata": {
diff --git a/decision_tree.pdf b/decision_tree.pdf
index 628619bd..4b67ae7d 100644
Binary files a/decision_tree.pdf and b/decision_tree.pdf differ