Compare commits

..

2 Commits

Author SHA1 Message Date
87901ed254 Merge branch 'main' of https://gitea.jany.se/Jany/MLPproject 2025-10-22 11:37:10 +02:00
d3ab195d63 Added CV 2025-10-22 11:36:56 +02:00
3 changed files with 103 additions and 249 deletions

View File

@@ -11,7 +11,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"id": "557ed2b5", "id": "557ed2b5",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -26,247 +26,32 @@
"output_type": "display_data" "output_type": "display_data"
}, },
{ {
"data": { "name": "stdout",
"text/html": [ "output_type": "stream",
"<div>\n", "text": [
"<style scoped>\n", "<class 'pandas.core.frame.DataFrame'>\n",
" .dataframe tbody tr th:only-of-type {\n", "RangeIndex: 32561 entries, 0 to 32560\n",
" vertical-align: middle;\n", "Data columns (total 15 columns):\n",
" }\n", " # Column Non-Null Count Dtype \n",
"\n", "--- ------ -------------- ----- \n",
" .dataframe tbody tr th {\n", " 0 age 32561 non-null int64 \n",
" vertical-align: top;\n", " 1 workclass 32561 non-null object\n",
" }\n", " 2 fnlwgt 32561 non-null int64 \n",
"\n", " 3 education 32561 non-null object\n",
" .dataframe thead th {\n", " 4 education.num 32561 non-null int64 \n",
" text-align: right;\n", " 5 marital.status 32561 non-null object\n",
" }\n", " 6 occupation 32561 non-null object\n",
"</style>\n", " 7 relationship 32561 non-null object\n",
"<table border=\"1\" class=\"dataframe\">\n", " 8 race 32561 non-null object\n",
" <thead>\n", " 9 sex 32561 non-null object\n",
" <tr style=\"text-align: right;\">\n", " 10 capital.gain 32561 non-null int64 \n",
" <th></th>\n", " 11 capital.loss 32561 non-null int64 \n",
" <th>age</th>\n", " 12 hours.per.week 32561 non-null int64 \n",
" <th>workclass</th>\n", " 13 native.country 32561 non-null object\n",
" <th>education.num</th>\n", " 14 income 32561 non-null object\n",
" <th>marital.status</th>\n", "dtypes: int64(6), object(9)\n",
" <th>occupation</th>\n", "memory usage: 3.7+ MB\n"
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>sex</th>\n",
" <th>capital.gain</th>\n",
" <th>capital.loss</th>\n",
" <th>hours.per.week</th>\n",
" <th>native.country</th>\n",
" <th>income</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>82</td>\n",
" <td>2</td>\n",
" <td>9</td>\n",
" <td>6</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>4356</td>\n",
" <td>18</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>54</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>6</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3900</td>\n",
" <td>40</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>41</td>\n",
" <td>2</td>\n",
" <td>10</td>\n",
" <td>5</td>\n",
" <td>9</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3900</td>\n",
" <td>40</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>34</td>\n",
" <td>2</td>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>7</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3770</td>\n",
" <td>45</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>38</td>\n",
" <td>2</td>\n",
" <td>6</td>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3770</td>\n",
" <td>40</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>74</td>\n",
" <td>5</td>\n",
" <td>16</td>\n",
" <td>4</td>\n",
" <td>9</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3683</td>\n",
" <td>20</td>\n",
" <td>38</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>68</td>\n",
" <td>0</td>\n",
" <td>9</td>\n",
" <td>0</td>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3683</td>\n",
" <td>40</td>\n",
" <td>38</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>45</td>\n",
" <td>2</td>\n",
" <td>16</td>\n",
" <td>0</td>\n",
" <td>9</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>3004</td>\n",
" <td>35</td>\n",
" <td>38</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>38</td>\n",
" <td>4</td>\n",
" <td>15</td>\n",
" <td>4</td>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2824</td>\n",
" <td>45</td>\n",
" <td>38</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>52</td>\n",
" <td>2</td>\n",
" <td>13</td>\n",
" <td>6</td>\n",
" <td>7</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2824</td>\n",
" <td>20</td>\n",
" <td>38</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass education.num marital.status occupation relationship \\\n",
"1 82 2 9 6 3 1 \n",
"3 54 2 4 0 6 4 \n",
"4 41 2 10 5 9 3 \n",
"5 34 2 9 0 7 4 \n",
"6 38 2 6 5 0 4 \n",
"7 74 5 16 4 9 2 \n",
"8 68 0 9 0 9 1 \n",
"10 45 2 16 0 9 4 \n",
"11 38 4 15 4 9 1 \n",
"12 52 2 13 6 7 1 \n",
"\n",
" race sex capital.gain capital.loss hours.per.week native.country \\\n",
"1 4 0 0 4356 18 38 \n",
"3 4 0 0 3900 40 38 \n",
"4 4 0 0 3900 40 38 \n",
"5 4 0 0 3770 45 38 \n",
"6 4 1 0 3770 40 38 \n",
"7 4 0 0 3683 20 38 \n",
"8 4 0 0 3683 40 38 \n",
"10 2 0 0 3004 35 38 \n",
"11 4 1 0 2824 45 38 \n",
"12 4 0 0 2824 20 38 \n",
"\n",
" income \n",
"1 0 \n",
"3 0 \n",
"4 0 \n",
"5 0 \n",
"6 0 \n",
"7 1 \n",
"8 0 \n",
"10 1 \n",
"11 1 \n",
"12 1 "
] ]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
} }
], ],
"source": [ "source": [
@@ -306,6 +91,7 @@
"plt.show()\n", "plt.show()\n",
"\n", "\n",
"#df_encoded.head(10)\n", "#df_encoded.head(10)\n",
"df.info()\n",
"\n" "\n"
] ]
}, },

View File

@@ -8,7 +8,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 24,
"id": "0952f099", "id": "0952f099",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -45,7 +45,26 @@
" accuracy 0.81 6033\n", " accuracy 0.81 6033\n",
" macro avg 0.74 0.75 0.74 6033\n", " macro avg 0.74 0.75 0.74 6033\n",
"weighted avg 0.81 0.81 0.81 6033\n", "weighted avg 0.81 0.81 0.81 6033\n",
"\n" "\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 18096 entries, 13586 to 5836\n",
"Data columns (total 12 columns):\n",
" # Column Non-Null Count Dtype\n",
"--- ------ -------------- -----\n",
" 0 age 18096 non-null int64\n",
" 1 workclass 18096 non-null int64\n",
" 2 education.num 18096 non-null int64\n",
" 3 marital.status 18096 non-null int64\n",
" 4 occupation 18096 non-null int64\n",
" 5 relationship 18096 non-null int64\n",
" 6 race 18096 non-null int64\n",
" 7 sex 18096 non-null int64\n",
" 8 capital.gain 18096 non-null int64\n",
" 9 capital.loss 18096 non-null int64\n",
" 10 hours.per.week 18096 non-null int64\n",
" 11 native.country 18096 non-null int64\n",
"dtypes: int64(12)\n",
"memory usage: 1.8 MB\n"
] ]
} }
], ],
@@ -53,12 +72,14 @@
"import pandas as pd\n", "import pandas as pd\n",
"import numpy as np\n", "import numpy as np\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import train_test_split\n", "from sklearn.model_selection import train_test_split, RandomizedSearchCV\n",
"from sklearn.pipeline import Pipeline\n", "from sklearn.pipeline import Pipeline\n",
"from sklearn.tree import DecisionTreeClassifier, plot_tree\n", "from sklearn.tree import DecisionTreeClassifier, plot_tree\n",
"from sklearn.preprocessing import LabelEncoder\n", "from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error , r2_score\n", "from sklearn.metrics import mean_squared_error, mean_absolute_error , r2_score\n",
"from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n", "from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n",
"from scipy.stats import randint\n",
"\n",
"\n", "\n",
"# Load dataset\n", "# Load dataset\n",
"df = pd.read_csv('./Datasets/adult.csv', comment = '#')\n", "df = pd.read_csv('./Datasets/adult.csv', comment = '#')\n",
@@ -123,16 +144,63 @@
"plt.show() \n", "plt.show() \n",
"\n", "\n",
"print(\"Classification Report:\")\n", "print(\"Classification Report:\")\n",
"print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))" "print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))\n",
"\n",
"X_train.info()"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 26,
"id": "e567e4e9", "id": "e567e4e9",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
"source": [] {
"name": "stdout",
"output_type": "stream",
"text": [
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" Poor 0.89 0.92 0.90 4524\n",
" Rich 0.73 0.65 0.68 1509\n",
"\n",
" accuracy 0.85 6033\n",
" macro avg 0.81 0.78 0.79 6033\n",
"weighted avg 0.85 0.85 0.85 6033\n",
"\n"
]
}
],
"source": [
"# Hyperparameters search space\n",
"param_dist = {\n",
" 'full_dt_classifier__max_depth': randint(3, 20),\n",
" 'full_dt_classifier__min_samples_split': randint(2, 10),\n",
" 'full_dt_classifier__min_samples_leaf': randint(1, 10),\n",
"}\n",
"\n",
"# Ranodmized search for hyperparameter tuning\n",
"random_search = RandomizedSearchCV(\n",
" estimator=model,\n",
" param_distributions=param_dist,\n",
" n_iter = 50,\n",
" cv = 10,\n",
" scoring = 'r2',\n",
" n_jobs = -1,\n",
" random_state = 42\n",
")\n",
"\n",
"# Fit search\n",
"random_search.fit(X_train, y_train)\n",
"\n",
"# Best model training\n",
"best_model = random_search.best_estimator_\n",
"y_pred_best = best_model.predict(X_val)\n",
"\n",
"print(\"Classification Report:\")\n",
"print(classification_report(y_val, y_pred_best, target_names=[\"Poor\", \"Rich\"]))"
]
} }
], ],
"metadata": { "metadata": {

Binary file not shown.