Merge branch 'main' of https://gitea.jany.se/Jany/MLPproject
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 24,
|
||||
"id": "0952f099",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -45,7 +45,26 @@
|
||||
" accuracy 0.81 6033\n",
|
||||
" macro avg 0.74 0.75 0.74 6033\n",
|
||||
"weighted avg 0.81 0.81 0.81 6033\n",
|
||||
"\n"
|
||||
"\n",
|
||||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||||
"Index: 18096 entries, 13586 to 5836\n",
|
||||
"Data columns (total 12 columns):\n",
|
||||
" # Column Non-Null Count Dtype\n",
|
||||
"--- ------ -------------- -----\n",
|
||||
" 0 age 18096 non-null int64\n",
|
||||
" 1 workclass 18096 non-null int64\n",
|
||||
" 2 education.num 18096 non-null int64\n",
|
||||
" 3 marital.status 18096 non-null int64\n",
|
||||
" 4 occupation 18096 non-null int64\n",
|
||||
" 5 relationship 18096 non-null int64\n",
|
||||
" 6 race 18096 non-null int64\n",
|
||||
" 7 sex 18096 non-null int64\n",
|
||||
" 8 capital.gain 18096 non-null int64\n",
|
||||
" 9 capital.loss 18096 non-null int64\n",
|
||||
" 10 hours.per.week 18096 non-null int64\n",
|
||||
" 11 native.country 18096 non-null int64\n",
|
||||
"dtypes: int64(12)\n",
|
||||
"memory usage: 1.8 MB\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -53,12 +72,14 @@
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from sklearn.model_selection import train_test_split, RandomizedSearchCV\n",
|
||||
"from sklearn.pipeline import Pipeline\n",
|
||||
"from sklearn.tree import DecisionTreeClassifier, plot_tree\n",
|
||||
"from sklearn.preprocessing import LabelEncoder\n",
|
||||
"from sklearn.metrics import mean_squared_error, mean_absolute_error , r2_score\n",
|
||||
"from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n",
|
||||
"from scipy.stats import randint\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Load dataset\n",
|
||||
"df = pd.read_csv('./Datasets/adult.csv', comment = '#')\n",
|
||||
@@ -123,16 +144,63 @@
|
||||
"plt.show() \n",
|
||||
"\n",
|
||||
"print(\"Classification Report:\")\n",
|
||||
"print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))"
|
||||
"print(classification_report(y_val, y_pred, target_names=[\"Poor\", \"Rich\"]))\n",
|
||||
"\n",
|
||||
"X_train.info()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 26,
|
||||
"id": "e567e4e9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Classification Report:\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" Poor 0.89 0.92 0.90 4524\n",
|
||||
" Rich 0.73 0.65 0.68 1509\n",
|
||||
"\n",
|
||||
" accuracy 0.85 6033\n",
|
||||
" macro avg 0.81 0.78 0.79 6033\n",
|
||||
"weighted avg 0.85 0.85 0.85 6033\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Hyperparameters search space\n",
|
||||
"param_dist = {\n",
|
||||
" 'full_dt_classifier__max_depth': randint(3, 20),\n",
|
||||
" 'full_dt_classifier__min_samples_split': randint(2, 10),\n",
|
||||
" 'full_dt_classifier__min_samples_leaf': randint(1, 10),\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Ranodmized search for hyperparameter tuning\n",
|
||||
"random_search = RandomizedSearchCV(\n",
|
||||
" estimator=model,\n",
|
||||
" param_distributions=param_dist,\n",
|
||||
" n_iter = 50,\n",
|
||||
" cv = 10,\n",
|
||||
" scoring = 'r2',\n",
|
||||
" n_jobs = -1,\n",
|
||||
" random_state = 42\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Fit search\n",
|
||||
"random_search.fit(X_train, y_train)\n",
|
||||
"\n",
|
||||
"# Best model training\n",
|
||||
"best_model = random_search.best_estimator_\n",
|
||||
"y_pred_best = best_model.predict(X_val)\n",
|
||||
"\n",
|
||||
"print(\"Classification Report:\")\n",
|
||||
"print(classification_report(y_val, y_pred_best, target_names=[\"Poor\", \"Rich\"]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Reference in New Issue
Block a user