This commit is contained in:
2025-10-22 12:34:23 +02:00
3 changed files with 103 additions and 249 deletions

View File

@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "557ed2b5",
"metadata": {},
"outputs": [
@@ -26,247 +26,32 @@
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>education.num</th>\n",
" <th>marital.status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>sex</th>\n",
" <th>capital.gain</th>\n",
" <th>capital.loss</th>\n",
" <th>hours.per.week</th>\n",
" <th>native.country</th>\n",
" <th>income</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>82</td>\n",
" <td>2</td>\n",
" <td>9</td>\n",
" <td>6</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>4356</td>\n",
" <td>18</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>54</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>6</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3900</td>\n",
" <td>40</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>41</td>\n",
" <td>2</td>\n",
" <td>10</td>\n",
" <td>5</td>\n",
" <td>9</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3900</td>\n",
" <td>40</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>34</td>\n",
" <td>2</td>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>7</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3770</td>\n",
" <td>45</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>38</td>\n",
" <td>2</td>\n",
" <td>6</td>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3770</td>\n",
" <td>40</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>74</td>\n",
" <td>5</td>\n",
" <td>16</td>\n",
" <td>4</td>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3683</td>\n",
" <td>20</td>\n",
" <td>38</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>68</td>\n",
" <td>0</td>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3683</td>\n",
" <td>40</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>45</td>\n",
" <td>2</td>\n",
" <td>16</td>\n",
" <td>0</td>\n",
" <td>9</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3004</td>\n",
" <td>35</td>\n",
" <td>38</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>38</td>\n",
" <td>4</td>\n",
" <td>15</td>\n",
" <td>4</td>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2824</td>\n",
" <td>45</td>\n",
" <td>38</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>52</td>\n",
" <td>2</td>\n",
" <td>13</td>\n",
" <td>6</td>\n",
" <td>7</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2824</td>\n",
" <td>20</td>\n",
" <td>38</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass education.num marital.status occupation relationship \\\n",
"1 82 2 9 6 3 1 \n",
"3 54 2 4 0 6 4 \n",
"4 41 2 10 5 9 3 \n",
"5 34 2 9 0 7 4 \n",
"6 38 2 6 5 0 4 \n",
"7 74 5 16 4 9 2 \n",
"8 68 0 9 0 9 1 \n",
"10 45 2 16 0 9 4 \n",
"11 38 4 15 4 9 1 \n",
"12 52 2 13 6 7 1 \n",
"\n",
" race sex capital.gain capital.loss hours.per.week native.country \\\n",
"1 4 0 0 4356 18 38 \n",
"3 4 0 0 3900 40 38 \n",
"4 4 0 0 3900 40 38 \n",
"5 4 0 0 3770 45 38 \n",
"6 4 1 0 3770 40 38 \n",
"7 4 0 0 3683 20 38 \n",
"8 4 0 0 3683 40 38 \n",
"10 2 0 0 3004 35 38 \n",
"11 4 1 0 2824 45 38 \n",
"12 4 0 0 2824 20 38 \n",
"\n",
" income \n",
"1 0 \n",
"3 0 \n",
"4 0 \n",
"5 0 \n",
"6 0 \n",
"7 1 \n",
"8 0 \n",
"10 1 \n",
"11 1 \n",
"12 1 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 32561 entries, 0 to 32560\n",
"Data columns (total 15 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 age 32561 non-null int64 \n",
" 1 workclass 32561 non-null object\n",
" 2 fnlwgt 32561 non-null int64 \n",
" 3 education 32561 non-null object\n",
" 4 education.num 32561 non-null int64 \n",
" 5 marital.status 32561 non-null object\n",
" 6 occupation 32561 non-null object\n",
" 7 relationship 32561 non-null object\n",
" 8 race 32561 non-null object\n",
" 9 sex 32561 non-null object\n",
" 10 capital.gain 32561 non-null int64 \n",
" 11 capital.loss 32561 non-null int64 \n",
" 12 hours.per.week 32561 non-null int64 \n",
" 13 native.country 32561 non-null object\n",
" 14 income 32561 non-null object\n",
"dtypes: int64(6), object(9)\n",
"memory usage: 3.7+ MB\n"
]
}
],
"source": [
@@ -306,6 +91,7 @@
"plt.show()\n",
"\n",
"#df_encoded.head(10)\n",
"df.info()\n",
"\n"
]
},

View File

@@ -8,7 +8,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 24,
"id": "0952f099",
"metadata": {},
"outputs": [
@@ -45,7 +45,26 @@
" accuracy 0.81 6033\n",
" macro avg 0.74 0.75 0.74 6033\n",
"weighted avg 0.81 0.81 0.81 6033\n",
"\n"
"\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 18096 entries, 13586 to 5836\n",
"Data columns (total 12 columns):\n",
" # Column Non-Null Count Dtype\n",
"--- ------ -------------- -----\n",
" 0 age 18096 non-null int64\n",
" 1 workclass 18096 non-null int64\n",
" 2 education.num 18096 non-null int64\n",
" 3 marital.status 18096 non-null int64\n",
" 4 occupation 18096 non-null int64\n",
" 5 relationship 18096 non-null int64\n",
" 6 race 18096 non-null int64\n",
" 7 sex 18096 non-null int64\n",
" 8 capital.gain 18096 non-null int64\n",
" 9 capital.loss 18096 non-null int64\n",
" 10 hours.per.week 18096 non-null int64\n",
" 11 native.country 18096 non-null int64\n",
"dtypes: int64(12)\n",
"memory usage: 1.8 MB\n"
]
}
],
@@ -53,12 +72,14 @@
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.model_selection import train_test_split, RandomizedSearchCV\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.tree import DecisionTreeClassifier, plot_tree\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error , r2_score\n",
"from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n",
"from scipy.stats import randint\n",
"\n",
"\n",
"# Load dataset\n",
"df = pd.read_csv('./Datasets/adult.csv', comment = '#')\n",
@@ -123,16 +144,63 @@
"plt.show() \n",
"\n",
"print(\"Classification Report:\")\n",
"print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))"
"print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))\n",
"\n",
"X_train.info()"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"id": "e567e4e9",
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" Poor 0.89 0.92 0.90 4524\n",
" Rich 0.73 0.65 0.68 1509\n",
"\n",
" accuracy 0.85 6033\n",
" macro avg 0.81 0.78 0.79 6033\n",
"weighted avg 0.85 0.85 0.85 6033\n",
"\n"
]
}
],
"source": [
"# Hyperparameters search space\n",
"param_dist = {\n",
" 'full_dt_classifier__max_depth': randint(3, 20),\n",
" 'full_dt_classifier__min_samples_split': randint(2, 10),\n",
" 'full_dt_classifier__min_samples_leaf': randint(1, 10),\n",
"}\n",
"\n",
"# Ranodmized search for hyperparameter tuning\n",
"random_search = RandomizedSearchCV(\n",
" estimator=model,\n",
" param_distributions=param_dist,\n",
" n_iter = 50,\n",
" cv = 10,\n",
" scoring = 'r2',\n",
" n_jobs = -1,\n",
" random_state = 42\n",
")\n",
"\n",
"# Fit search\n",
"random_search.fit(X_train, y_train)\n",
"\n",
"# Best model training\n",
"best_model = random_search.best_estimator_\n",
"y_pred_best = best_model.predict(X_val)\n",
"\n",
"print(\"Classification Report:\")\n",
"print(classification_report(y_val, y_pred_best, target_names=[\"Poor\", \"Rich\"]))"
]
}
],
"metadata": {

Binary file not shown.