{
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
},
"papermill": {
"default_parameters": {},
"duration": 21629.871921,
"end_time": "2021-08-01T03:02:57.256605",
"environment_variables": {},
"exception": null,
"input_path": "__notebook__.ipynb",
"output_path": "__notebook__.ipynb",
"parameters": {},
"start_time": "2021-07-31T21:02:27.384684",
"version": "2.3.3"
},
"colab": {
"name": "text classification using BERT.ipynb",
"provenance": []
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Lk52n6Tt0uJO"
},
"source": [
"# Text Classification using BERT\n",
"\n",
"This is a basic implementation of a text classification pipeline using BERT. The BERT model is taken from [huggingface](https://huggingface.co/transformers/). The dataset is a custom dataset with two classes (labelled 0 and 1). It is publicly available [here](https://raw.githubusercontent.com/prateekjoshi565/Fine-Tuning-BERT/master/spamdata_v2.csv)."
],
"id": "Lk52n6Tt0uJO"
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:02:35.719217Z",
"iopub.status.busy": "2021-07-31T21:02:35.717605Z",
"iopub.status.idle": "2021-07-31T21:02:35.722099Z",
"shell.execute_reply": "2021-07-31T21:02:35.722966Z",
"shell.execute_reply.started": "2021-07-31T08:17:49.337897Z"
},
"papermill": {
"duration": 0.024617,
"end_time": "2021-07-31T21:02:35.723409",
"exception": false,
"start_time": "2021-07-31T21:02:35.698792",
"status": "completed"
},
"tags": [],
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "788d3220",
"outputId": "54b8c951-4479-4581-9f17-d3f7e23bd477"
},
"source": [
"%pip install transformers==4.10.0"
],
"id": "788d3220",
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting transformers\n",
" Downloading transformers-4.10.0-py3-none-any.whl (2.8 MB)\n",
"\u001b[K |鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅| 2.8 MB 4.0 MB/s \n",
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.0.12)\n",
"Collecting huggingface-hub>=0.0.12\n",
" Downloading huggingface_hub-0.0.16-py3-none-any.whl (50 kB)\n",
"\u001b[K |鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅| 50 kB 6.8 MB/s \n",
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n",
"Collecting pyyaml>=5.1\n",
" Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)\n",
"\u001b[K |鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅| 636 kB 51.9 MB/s \n",
"\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers) (21.0)\n",
"Collecting tokenizers<0.11,>=0.10.1\n",
" Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)\n",
"\u001b[K |鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅| 3.3 MB 21.4 MB/s \n",
"\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.0)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n",
"Collecting sacremoses\n",
" Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)\n",
"\u001b[K |鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅| 895 kB 45.3 MB/s \n",
"\u001b[?25hRequirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.6.4)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.12->transformers) (3.7.4.3)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers) (2.4.7)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.5.0)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.5.30)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.15.0)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n",
"Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.0.1)\n",
"Installing collected packages: tokenizers, sacremoses, pyyaml, huggingface-hub, transformers\n",
" Attempting uninstall: pyyaml\n",
" Found existing installation: PyYAML 3.13\n",
" Uninstalling PyYAML-3.13:\n",
" Successfully uninstalled PyYAML-3.13\n",
"Successfully installed huggingface-hub-0.0.16 pyyaml-5.4.1 sacremoses-0.0.45 tokenizers-0.10.3 transformers-4.10.0\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:05:44.436123Z",
"iopub.status.busy": "2021-07-31T21:05:44.435351Z",
"iopub.status.idle": "2021-07-31T21:06:30.493736Z",
"shell.execute_reply": "2021-07-31T21:06:30.492714Z",
"shell.execute_reply.started": "2021-07-31T08:21:08.023057Z"
},
"papermill": {
"duration": 46.206213,
"end_time": "2021-07-31T21:06:30.494123",
"exception": false,
"start_time": "2021-07-31T21:05:44.287910",
"status": "completed"
},
"tags": [],
"id": "c90ffee0"
},
"source": [
"import csv\n",
"import pickle\n",
"import pandas as pd\n",
"import numpy as np\n",
"# Prefer a local copy of the dataset; fall back to the public copy when the\n",
"# file is missing so the notebook also runs from a fresh environment.\n",
"try:\n",
"    train = pd.read_csv(\"spamdata_v2.csv\")\n",
"except FileNotFoundError:\n",
"    train = pd.read_csv(\n",
"        \"https://raw.githubusercontent.com/prateekjoshi565/Fine-Tuning-BERT/master/spamdata_v2.csv\"\n",
"    )\n"
],
"id": "c90ffee0",
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:06:30.792497Z",
"iopub.status.busy": "2021-07-31T21:06:30.791325Z",
"iopub.status.idle": "2021-07-31T21:06:31.512114Z",
"shell.execute_reply": "2021-07-31T21:06:31.512646Z",
"shell.execute_reply.started": "2021-07-31T08:21:53.599755Z"
},
"papermill": {
"duration": 0.877222,
"end_time": "2021-07-31T21:06:31.512821",
"exception": false,
"start_time": "2021-07-31T21:06:30.635599",
"status": "completed"
},
"tags": [],
"colab": {
"base_uri": "https://localhost:8080/",
"height": 219
},
"id": "fdb5210b",
"outputId": "421a3979-dfdd-4a0f-d331-4b328ba545f3"
},
"source": [
"# Show the number of rows, then preview the first few records.\n",
"print(train.shape[0])\n",
"train.head()"
],
"id": "fdb5210b",
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"5572\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>Go until jurong point, crazy.. Available only ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>Ok lar... Joking wif u oni...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>Free entry in 2 a wkly comp to win FA Cup fina...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>U dun say so early hor... U c already then say...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>Nah I don't think he goes to usf, he lives aro...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" label text\n",
"0 0 Go until jurong point, crazy.. Available only ...\n",
"1 0 Ok lar... Joking wif u oni...\n",
"2 1 Free entry in 2 a wkly comp to win FA Cup fina...\n",
"3 0 U dun say so early hor... U c already then say...\n",
"4 0 Nah I don't think he goes to usf, he lives aro..."
]
},
"metadata": {},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "mxdJUOsolsMj"
},
"source": [
"# Number of target classes; the dataset labels are 0 and 1.\n",
"num_classes = 2"
],
"id": "mxdJUOsolsMj",
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:06:31.796699Z",
"iopub.status.busy": "2021-07-31T21:06:31.795920Z",
"iopub.status.idle": "2021-07-31T21:06:39.859089Z",
"shell.execute_reply": "2021-07-31T21:06:39.859561Z",
"shell.execute_reply.started": "2021-07-31T08:21:54.413397Z"
},
"papermill": {
"duration": 8.208197,
"end_time": "2021-07-31T21:06:39.859775",
"exception": false,
"start_time": "2021-07-31T21:06:31.651578",
"status": "completed"
},
"tags": [],
"id": "eb38df80"
},
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hold out 5% of the rows for validation. A fixed seed makes the split\n",
"# reproducible across kernel restarts, and stratifying on the label keeps the\n",
"# class ratio of the small validation set close to that of the full dataset.\n",
"train_split, val_split = train_test_split(\n",
"    train,\n",
"    test_size=0.05,\n",
"    random_state=42,\n",
"    stratify=train[\"label\"],\n",
")"
],
"id": "eb38df80",
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:06:40.148699Z",
"iopub.status.busy": "2021-07-31T21:06:40.147996Z",
"iopub.status.idle": "2021-07-31T21:06:40.188820Z",
"shell.execute_reply": "2021-07-31T21:06:40.188250Z",
"shell.execute_reply.started": "2021-07-31T08:22:01.764055Z"
},
"papermill": {
"duration": 0.187948,
"end_time": "2021-07-31T21:06:40.188979",
"exception": false,
"start_time": "2021-07-31T21:06:40.001031",
"status": "completed"
},
"tags": [],
"id": "a9bbf53d"
},
"source": [
"from transformers import BertTokenizerFast\n",
"# Load the pretrained tokenizer for the 'bert-base-uncased' checkpoint.\n",
"tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')"
],
"id": "a9bbf53d",
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:06:40.486360Z",
"iopub.status.busy": "2021-07-31T21:06:40.485678Z",
"iopub.status.idle": "2021-07-31T21:06:51.481509Z",
"shell.execute_reply": "2021-07-31T21:06:51.480780Z",
"shell.execute_reply.started": "2021-07-31T08:22:01.807511Z"
},
"papermill": {
"duration": 11.150744,
"end_time": "2021-07-31T21:06:51.481655",
"exception": false,
"start_time": "2021-07-31T21:06:40.330911",
"status": "completed"
},
"tags": [],
"id": "f91e37b7"
},
"source": [
"import torch\n",
"\n",
"class Dataset(torch.utils.data.Dataset):\n",
"    \"\"\"Map-style dataset that converts text/label rows into fixed-length model inputs.\n",
"\n",
"    Args:\n",
"        df: DataFrame with a ``text`` column and a ``label`` column.\n",
"        tokenizer: Object providing ``tokenize`` and ``convert_tokens_to_ids``.\n",
"        max_length: Length of every returned sequence, including the\n",
"            ``[CLS]`` and ``[SEP]`` markers. Must be at least 2.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``max_length`` is smaller than 2.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, df, tokenizer, max_length=128):\n",
"        if max_length < 2:\n",
"            # Two positions are always taken by the [CLS] and [SEP] markers.\n",
"            raise ValueError(\"max_length must be at least 2\")\n",
"        self.df = df\n",
"        self.text = df.text.values\n",
"        self.labels = df.label.values\n",
"        self.tokenizer = tokenizer\n",
"        self.max_length = max_length\n",
"\n",
"    def __getitem__(self, idx):\n",
"        \"\"\"Return ``input_ids``, ``attention_mask`` and ``labels`` tensors for row ``idx``.\"\"\"\n",
"        # Use the tokenizer handed to the constructor instead of a notebook\n",
"        # global, so the dataset does not depend on hidden kernel state.\n",
"        tokens = self.tokenizer.tokenize(self.text[idx])\n",
"        # Truncate so that the markers still fit within max_length.\n",
"        tokens = [\"[CLS]\"] + tokens[:self.max_length - 2] + [\"[SEP]\"]\n",
"        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)\n",
"        attention_mask = [1] * len(input_ids)\n",
"        # Right-pad ids and mask with 0 up to max_length.\n",
"        pad = [0] * (self.max_length - len(input_ids))\n",
"        return {\n",
"            \"input_ids\": torch.tensor(input_ids + pad, dtype=torch.long),\n",
"            \"attention_mask\": torch.tensor(attention_mask + pad, dtype=torch.long),\n",
"            \"labels\": torch.tensor(self.labels[idx], dtype=torch.long),\n",
"        }\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the number of rows in the underlying DataFrame.\"\"\"\n",
"        return len(self.df)\n",
"\n",
"# Build the training and validation datasets; missing values are replaced\n",
"# with empty strings before tokenization.\n",
"train_dataset = Dataset(df=train_split.fillna(\"\"), tokenizer=tokenizer)\n",
"val_dataset = Dataset(df=val_split.fillna(\"\"), tokenizer=tokenizer)"
],
"id": "f91e37b7",
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:06:51.772366Z",
"iopub.status.busy": "2021-07-31T21:06:51.771623Z",
"iopub.status.idle": "2021-07-31T21:07:23.576720Z",
"shell.execute_reply": "2021-07-31T21:07:23.576103Z",
"shell.execute_reply.started": "2021-07-31T08:22:12.696992Z"
},
"papermill": {
"duration": 31.953274,
"end_time": "2021-07-31T21:07:23.576885",
"exception": false,
"start_time": "2021-07-31T21:06:51.623611",
"status": "completed"
},
"tags": [],
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6186779d",
"outputId": "9e0874d1-9977-4b48-9c98-3d19e5d4ccd0"
},
"source": [
"from transformers import BertForSequenceClassification, Trainer, TrainingArguments\n",
"\n",
"training_args = TrainingArguments(\n",
"    output_dir='./results',          # output directory\n",
"    num_train_epochs=50,             # total number of training epochs\n",
"    per_device_train_batch_size=64,  # batch size per device during training\n",
"    per_device_eval_batch_size=64,   # batch size for evaluation\n",
"    warmup_steps=500,                # number of warmup steps for learning rate scheduler\n",
"    weight_decay=0.01,               # strength of weight decay\n",
"    logging_dir='./logs',            # directory for storing logs\n",
"    logging_steps=100,               # log the training loss every 100 steps\n",
"    dataloader_num_workers=2,\n",
"    report_to=\"tensorboard\",\n",
"    label_smoothing_factor=0.1,\n",
"    evaluation_strategy=\"steps\",\n",
"    eval_steps=500,                  # evaluate every 500 steps\n",
"    save_total_limit=3,              # keep at most 3 checkpoints on disk; older ones are deleted\n",
"    load_best_model_at_end=True,     # reload the best checkpoint when training finishes\n",
")\n",
"\n",
"# Let the library build the classification head: passing num_labels sizes the\n",
"# classifier from the model's own hidden size and records the label count in\n",
"# the model config, so saved checkpoints stay consistent with the head.\n",
"model = BertForSequenceClassification.from_pretrained(\n",
"    \"bert-base-uncased\",\n",
"    num_labels=num_classes,\n",
")"
],
"id": "6186779d",
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"execution": {
"iopub.execute_input": "2021-07-31T21:07:23.873644Z",
"iopub.status.busy": "2021-07-31T21:07:23.872942Z",
"iopub.status.idle": "2021-08-01T03:02:53.329971Z",
"shell.execute_reply": "2021-08-01T03:02:53.330600Z"
},
"papermill": {
"duration": 21329.605812,
"end_time": "2021-08-01T03:02:53.332339",
"exception": false,
"start_time": "2021-07-31T21:07:23.726527",
"status": "completed"
},
"tags": [],
"colab": {
"base_uri": "https://localhost:8080/",
"height": 675
},
"id": "13abf542",
"outputId": "42a52c37-29f9-4f43-a70f-c285886f0c10"
},
"source": [
"# Wire the model, the training configuration and both datasets into a\n",
"# Trainer, then start fine-tuning.\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=train_dataset,\n",
"    eval_dataset=val_dataset,\n",
")\n",
"trainer.train()"
],
"id": "13abf542",
"execution_count": null,
"outputs": [
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"***** Running training *****\n",
" Num examples = 5293\n",
" Num Epochs = 50\n",
" Instantaneous batch size per device = 64\n",
" Total train batch size (w. parallel, distributed & accumulation) = 64\n",
" Gradient Accumulation steps = 1\n",
" Total optimization steps = 4150\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='408' max='4150' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [ 408/4150 16:23 < 2:31:08, 0.41 it/s, Epoch 4.90/50]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='1507' max='4150' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [1507/4150 1:01:13 < 1:47:31, 0.41 it/s, Epoch 18.14/50]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>500</td>\n",
" <td>0.243900</td>\n",
" <td>0.221590</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1000</td>\n",
" <td>0.200800</td>\n",
" <td>0.217612</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1500</td>\n",
" <td>0.198900</td>\n",
" <td>0.218175</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"***** Running Evaluation *****\n",
" Num examples = 279\n",
" Batch size = 64\n",
"Saving model checkpoint to ./results/checkpoint-500\n",
"Configuration saved in ./results/checkpoint-500/config.json\n",
"Model weights saved in ./results/checkpoint-500/pytorch_model.bin\n",
"***** Running Evaluation *****\n",
" Num examples = 279\n",
" Batch size = 64\n",
"Saving model checkpoint to ./results/checkpoint-1000\n",
"Configuration saved in ./results/checkpoint-1000/config.json\n",
"Model weights saved in ./results/checkpoint-1000/pytorch_model.bin\n",
"***** Running Evaluation *****\n",
" Num examples = 279\n",
" Batch size = 64\n",
"Saving model checkpoint to ./results/checkpoint-1500\n",
"Configuration saved in ./results/checkpoint-1500/config.json\n",
"Model weights saved in ./results/checkpoint-1500/pytorch_model.bin\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "lsoxR09ymdHv"
},
"source": [
""
],
"id": "lsoxR09ymdHv",
"execution_count": null,
"outputs": []
}
]
}