Merge branch 'main' of https://gitea.jany.se/Jany/MLPproject
This commit is contained in:
270
Analysis.ipynb
270
Analysis.ipynb
@@ -11,7 +11,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"id": "557ed2b5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -26,247 +26,32 @@
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>age</th>\n",
|
||||
" <th>workclass</th>\n",
|
||||
" <th>education.num</th>\n",
|
||||
" <th>marital.status</th>\n",
|
||||
" <th>occupation</th>\n",
|
||||
" <th>relationship</th>\n",
|
||||
" <th>race</th>\n",
|
||||
" <th>sex</th>\n",
|
||||
" <th>capital.gain</th>\n",
|
||||
" <th>capital.loss</th>\n",
|
||||
" <th>hours.per.week</th>\n",
|
||||
" <th>native.country</th>\n",
|
||||
" <th>income</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>82</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>9</td>\n",
|
||||
" <td>6</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>4356</td>\n",
|
||||
" <td>18</td>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>54</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>6</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>3900</td>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>41</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>9</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>3900</td>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>34</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>9</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>7</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>3770</td>\n",
|
||||
" <td>45</td>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>6</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>3770</td>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>74</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>16</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>9</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>3683</td>\n",
|
||||
" <td>20</td>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>68</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>9</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>9</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>3683</td>\n",
|
||||
" <td>40</td>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>10</th>\n",
|
||||
" <td>45</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>16</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>9</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>3004</td>\n",
|
||||
" <td>35</td>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>11</th>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>15</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>9</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>2824</td>\n",
|
||||
" <td>45</td>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>12</th>\n",
|
||||
" <td>52</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>13</td>\n",
|
||||
" <td>6</td>\n",
|
||||
" <td>7</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>2824</td>\n",
|
||||
" <td>20</td>\n",
|
||||
" <td>38</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" age workclass education.num marital.status occupation relationship \\\n",
|
||||
"1 82 2 9 6 3 1 \n",
|
||||
"3 54 2 4 0 6 4 \n",
|
||||
"4 41 2 10 5 9 3 \n",
|
||||
"5 34 2 9 0 7 4 \n",
|
||||
"6 38 2 6 5 0 4 \n",
|
||||
"7 74 5 16 4 9 2 \n",
|
||||
"8 68 0 9 0 9 1 \n",
|
||||
"10 45 2 16 0 9 4 \n",
|
||||
"11 38 4 15 4 9 1 \n",
|
||||
"12 52 2 13 6 7 1 \n",
|
||||
"\n",
|
||||
" race sex capital.gain capital.loss hours.per.week native.country \\\n",
|
||||
"1 4 0 0 4356 18 38 \n",
|
||||
"3 4 0 0 3900 40 38 \n",
|
||||
"4 4 0 0 3900 40 38 \n",
|
||||
"5 4 0 0 3770 45 38 \n",
|
||||
"6 4 1 0 3770 40 38 \n",
|
||||
"7 4 0 0 3683 20 38 \n",
|
||||
"8 4 0 0 3683 40 38 \n",
|
||||
"10 2 0 0 3004 35 38 \n",
|
||||
"11 4 1 0 2824 45 38 \n",
|
||||
"12 4 0 0 2824 20 38 \n",
|
||||
"\n",
|
||||
" income \n",
|
||||
"1 0 \n",
|
||||
"3 0 \n",
|
||||
"4 0 \n",
|
||||
"5 0 \n",
|
||||
"6 0 \n",
|
||||
"7 1 \n",
|
||||
"8 0 \n",
|
||||
"10 1 \n",
|
||||
"11 1 \n",
|
||||
"12 1 "
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||||
"RangeIndex: 32561 entries, 0 to 32560\n",
|
||||
"Data columns (total 15 columns):\n",
|
||||
" # Column Non-Null Count Dtype \n",
|
||||
"--- ------ -------------- ----- \n",
|
||||
" 0 age 32561 non-null int64 \n",
|
||||
" 1 workclass 32561 non-null object\n",
|
||||
" 2 fnlwgt 32561 non-null int64 \n",
|
||||
" 3 education 32561 non-null object\n",
|
||||
" 4 education.num 32561 non-null int64 \n",
|
||||
" 5 marital.status 32561 non-null object\n",
|
||||
" 6 occupation 32561 non-null object\n",
|
||||
" 7 relationship 32561 non-null object\n",
|
||||
" 8 race 32561 non-null object\n",
|
||||
" 9 sex 32561 non-null object\n",
|
||||
" 10 capital.gain 32561 non-null int64 \n",
|
||||
" 11 capital.loss 32561 non-null int64 \n",
|
||||
" 12 hours.per.week 32561 non-null int64 \n",
|
||||
" 13 native.country 32561 non-null object\n",
|
||||
" 14 income 32561 non-null object\n",
|
||||
"dtypes: int64(6), object(9)\n",
|
||||
"memory usage: 3.7+ MB\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -306,6 +91,7 @@
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"#df_encoded.head(10)\n",
|
||||
"df.info()\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 24,
|
||||
"id": "0952f099",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -45,7 +45,26 @@
|
||||
" accuracy 0.81 6033\n",
|
||||
" macro avg 0.74 0.75 0.74 6033\n",
|
||||
"weighted avg 0.81 0.81 0.81 6033\n",
|
||||
"\n"
|
||||
"\n",
|
||||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||||
"Index: 18096 entries, 13586 to 5836\n",
|
||||
"Data columns (total 12 columns):\n",
|
||||
" # Column Non-Null Count Dtype\n",
|
||||
"--- ------ -------------- -----\n",
|
||||
" 0 age 18096 non-null int64\n",
|
||||
" 1 workclass 18096 non-null int64\n",
|
||||
" 2 education.num 18096 non-null int64\n",
|
||||
" 3 marital.status 18096 non-null int64\n",
|
||||
" 4 occupation 18096 non-null int64\n",
|
||||
" 5 relationship 18096 non-null int64\n",
|
||||
" 6 race 18096 non-null int64\n",
|
||||
" 7 sex 18096 non-null int64\n",
|
||||
" 8 capital.gain 18096 non-null int64\n",
|
||||
" 9 capital.loss 18096 non-null int64\n",
|
||||
" 10 hours.per.week 18096 non-null int64\n",
|
||||
" 11 native.country 18096 non-null int64\n",
|
||||
"dtypes: int64(12)\n",
|
||||
"memory usage: 1.8 MB\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -53,12 +72,14 @@
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from sklearn.model_selection import train_test_split, RandomizedSearchCV\n",
|
||||
"from sklearn.pipeline import Pipeline\n",
|
||||
"from sklearn.tree import DecisionTreeClassifier, plot_tree\n",
|
||||
"from sklearn.preprocessing import LabelEncoder\n",
|
||||
"from sklearn.metrics import mean_squared_error, mean_absolute_error , r2_score\n",
|
||||
"from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n",
|
||||
"from scipy.stats import randint\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Load dataset\n",
|
||||
"df = pd.read_csv('./Datasets/adult.csv', comment = '#')\n",
|
||||
@@ -123,16 +144,63 @@
|
||||
"plt.show() \n",
|
||||
"\n",
|
||||
"print(\"Classification Report:\")\n",
|
||||
"print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))"
|
||||
"print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))\n",
|
||||
"\n",
|
||||
"X_train.info()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 26,
|
||||
"id": "e567e4e9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Classification Report:\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" Poor 0.89 0.92 0.90 4524\n",
|
||||
" Rich 0.73 0.65 0.68 1509\n",
|
||||
"\n",
|
||||
" accuracy 0.85 6033\n",
|
||||
" macro avg 0.81 0.78 0.79 6033\n",
|
||||
"weighted avg 0.85 0.85 0.85 6033\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Hyperparameter search space; the 'full_dt_classifier__' prefix routes each key to that pipeline step\n",
|
||||
"param_dist = {\n",
|
||||
" 'full_dt_classifier__max_depth': randint(3, 20),\n",
|
||||
" 'full_dt_classifier__min_samples_split': randint(2, 10),\n",
|
||||
" 'full_dt_classifier__min_samples_leaf': randint(1, 10),\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Randomized search for hyperparameter tuning (50 sampled candidates, 10-fold CV)\n",
|
||||
"# The estimator is a classifier, so candidates are ranked with a classification metric\n",
|
||||
"# ('accuracy') rather than the regression metric R^2\n",
|
||||
"random_search = RandomizedSearchCV(\n",
|
||||
" estimator=model,\n",
|
||||
" param_distributions=param_dist,\n",
|
||||
" n_iter = 50,\n",
|
||||
" cv = 10,\n",
|
||||
" scoring = 'accuracy',\n",
|
||||
" n_jobs = -1,\n",
|
||||
" random_state = 42\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Fit search\n",
|
||||
"random_search.fit(X_train, y_train)\n",
|
||||
"\n",
|
||||
"# Evaluate the best estimator found by the search (already refit by the search) on the validation split\n",
|
||||
"best_model = random_search.best_estimator_\n",
|
||||
"y_pred_best = best_model.predict(X_val)\n",
|
||||
"\n",
|
||||
"print(\"Classification Report:\")\n",
|
||||
"print(classification_report(y_val, y_pred_best, target_names=[\"Poor\", \"Rich\"]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user