From d3ab195d635a64e34908d28d37ef039c06e3da8c Mon Sep 17 00:00:00 2001 From: Petrus Date: Wed, 22 Oct 2025 11:36:56 +0200 Subject: [PATCH] Added CV --- Analysis.ipynb | 270 +++++--------------------------------------- Decision_tree.ipynb | 82 ++++++++++++-- decision_tree.pdf | Bin 31960 -> 31960 bytes 3 files changed, 103 insertions(+), 249 deletions(-) diff --git a/Analysis.ipynb b/Analysis.ipynb index 17c4369b..4b2c9f80 100644 --- a/Analysis.ipynb +++ b/Analysis.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "557ed2b5", "metadata": {}, "outputs": [ @@ -26,247 +26,32 @@ "output_type": "display_data" }, { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ageworkclasseducation.nummarital.statusoccupationrelationshipracesexcapital.gaincapital.losshours.per.weeknative.countryincome
18229631400435618380
35424064400390040380
441210593400390040380
53429074400377045380
63826504410377040380
774516492400368320381
86809091400368340380
1045216094200300435381
1138415491410282445381
1252213671400282420381
\n", - "
" - ], - "text/plain": [ - " age workclass education.num marital.status occupation relationship \\\n", - "1 82 2 9 6 3 1 \n", - "3 54 2 4 0 6 4 \n", - "4 41 2 10 5 9 3 \n", - "5 34 2 9 0 7 4 \n", - "6 38 2 6 5 0 4 \n", - "7 74 5 16 4 9 2 \n", - "8 68 0 9 0 9 1 \n", - "10 45 2 16 0 9 4 \n", - "11 38 4 15 4 9 1 \n", - "12 52 2 13 6 7 1 \n", - "\n", - " race sex capital.gain capital.loss hours.per.week native.country \\\n", - "1 4 0 0 4356 18 38 \n", - "3 4 0 0 3900 40 38 \n", - "4 4 0 0 3900 40 38 \n", - "5 4 0 0 3770 45 38 \n", - "6 4 1 0 3770 40 38 \n", - "7 4 0 0 3683 20 38 \n", - "8 4 0 0 3683 40 38 \n", - "10 2 0 0 3004 35 38 \n", - "11 4 1 0 2824 45 38 \n", - "12 4 0 0 2824 20 38 \n", - "\n", - " income \n", - "1 0 \n", - "3 0 \n", - "4 0 \n", - "5 0 \n", - "6 0 \n", - "7 1 \n", - "8 0 \n", - "10 1 \n", - "11 1 \n", - "12 1 " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 32561 entries, 0 to 32560\n", + "Data columns (total 15 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 age 32561 non-null int64 \n", + " 1 workclass 32561 non-null object\n", + " 2 fnlwgt 32561 non-null int64 \n", + " 3 education 32561 non-null object\n", + " 4 education.num 32561 non-null int64 \n", + " 5 marital.status 32561 non-null object\n", + " 6 occupation 32561 non-null object\n", + " 7 relationship 32561 non-null object\n", + " 8 race 32561 non-null object\n", + " 9 sex 32561 non-null object\n", + " 10 capital.gain 32561 non-null int64 \n", + " 11 capital.loss 32561 non-null int64 \n", + " 12 hours.per.week 32561 non-null int64 \n", + " 13 native.country 32561 non-null object\n", + " 14 income 32561 non-null object\n", + "dtypes: int64(6), object(9)\n", + "memory usage: 3.7+ MB\n" + ] } ], "source": [ @@ -306,6 +91,7 @@ "plt.show()\n", "\n", "#df_encoded.head(10)\n", + "df.info()\n", "\n" ] }, diff --git a/Decision_tree.ipynb b/Decision_tree.ipynb index d61600ef..32ff772d 100644 --- a/Decision_tree.ipynb +++ b/Decision_tree.ipynb @@ -8,7 +8,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 24, "id": "0952f099", "metadata": {}, "outputs": [ @@ -45,7 +45,26 @@ " accuracy 0.81 6033\n", " macro avg 0.74 0.75 0.74 6033\n", "weighted avg 0.81 0.81 0.81 6033\n", - "\n" + "\n", + "\n", + "Index: 18096 entries, 13586 to 5836\n", + "Data columns (total 12 columns):\n", + " # Column Non-Null Count Dtype\n", + "--- ------ -------------- -----\n", + " 0 age 18096 non-null int64\n", + " 1 workclass 18096 non-null int64\n", + " 2 education.num 18096 non-null int64\n", + " 3 marital.status 18096 non-null int64\n", + " 4 occupation 18096 non-null int64\n", + " 5 relationship 18096 non-null int64\n", + " 6 race 18096 non-null int64\n", + " 7 sex 18096 non-null int64\n", + " 8 capital.gain 18096 non-null int64\n", + " 9 capital.loss 18096 non-null int64\n", + " 10 hours.per.week 18096 non-null int64\n", + " 11 native.country 18096 non-null int64\n", + "dtypes: int64(12)\n", + "memory usage: 1.8 MB\n" ] } ], @@ -53,12 +72,14 @@ "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "from sklearn.model_selection import train_test_split\n", + "from sklearn.model_selection import train_test_split, RandomizedSearchCV\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.tree import DecisionTreeClassifier, plot_tree\n", "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.metrics import mean_squared_error, mean_absolute_error , r2_score\n", "from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n", + "from scipy.stats import randint\n", + "\n", "\n", "# Load dataset\n", "df = pd.read_csv('./Datasets/adult.csv', comment = '#')\n", @@ -123,16 +144,63 @@ "plt.show() \n", "\n", "print(\"Classification Report:\")\n", - "print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))" + "print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))\n", + "\n", + "X_train.info()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "id": "e567e4e9", "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Classification Report:\n", + " precision recall f1-score support\n", + "\n", + " Poor 0.89 0.92 0.90 4524\n", + " Rich 0.73 0.65 0.68 1509\n", + "\n", + " accuracy 0.85 6033\n", + " macro avg 0.81 0.78 0.79 6033\n", + "weighted avg 0.85 0.85 0.85 6033\n", + "\n" + ] + } + ], + "source": [ + "# Hyperparameters search space\n", + "param_dist = {\n", + " 'full_dt_classifier__max_depth': randint(3, 20),\n", + " 'full_dt_classifier__min_samples_split': randint(2, 10),\n", + " 'full_dt_classifier__min_samples_leaf': randint(1, 10),\n", + "}\n", + "\n", + "# Ranodmized search for hyperparameter tuning\n", + "random_search = RandomizedSearchCV(\n", + " estimator=model,\n", + " param_distributions=param_dist,\n", + " n_iter = 50,\n", + " cv = 10,\n", + " scoring = 'r2',\n", + " n_jobs = -1,\n", + " random_state = 42\n", + ")\n", + "\n", + "# Fit search\n", + "random_search.fit(X_train, y_train)\n", + "\n", + "# Best model training\n", + "best_model = random_search.best_estimator_\n", + "y_pred_best = best_model.predict(X_val)\n", + "\n", + "print(\"Classification Report:\")\n", + "print(classification_report(y_val, y_pred_best, target_names=[\"Poor\", \"Rich\"]))" + ] } ], "metadata": { diff --git a/decision_tree.pdf b/decision_tree.pdf index 628619bde5e20bc8f1d8d65c471d96e4755fa441..4b67ae7dbf09c093020feea4f8bc4da09abaddf5 100644 GIT binary patch delta 22 ecmccdlkvt+#to~=*^LYhjf{*eH*YLgX9WOw^9bhv delta 22 ecmccdlkvt+#to~=*$oX%j7*G7H*YLgX9WOw{Rrm(