Notebook Twelve | Repository
Convolutional Neural Network
Andrea Leone
University of Trento
January 2022
import project
import numpy as np
import torch
project.notebook()
records = project.sql_query("""
SELECT vector, category FROM talks
WHERE vector IS NOT NULL
ORDER BY slug ASC;
""")
(x, y), (z, t) \
= train_set, test_set \
= splits \
= project.split_in_sets( records )
project.describe_sets(splits)
train_set => (0, 1376) (1, 1572) (2, 1052)
test_set  => (0, 243)  (1, 275)  (2, 192)
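The split helper divides the (vector, category) records into training and test sets; the counts above (4000 training and 710 test samples) point to roughly an 85/15 split. A minimal sketch of such a split, assuming records is a list of (vector, category) tuples and ignoring stratification, hypothetical rather than the actual project.split_in_sets:

import random

def split_in_sets(records, train_ratio=0.85, seed=42):
    # Hypothetical stand-in for project.split_in_sets: shuffle once,
    # then cut the records into a train part and a test part.
    random.seed(seed)
    random.shuffle(records)
    cut = int(len(records) * train_ratio)
    train, test = records[:cut], records[cut:]
    x, y = [r[0] for r in train], [r[1] for r in train]
    z, t = [r[0] for r in test], [r[1] for r in test]
    return (x, y), (z, t)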
class Network(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor with matching batch-norm layers
        # (cnv5 and bn5 are declared but not used in forward below).
        self.cnv1 = torch.nn.Conv1d(  1,  5, kernel_size=3 )
        self.cnv2 = torch.nn.Conv1d(  5, 10, kernel_size=3 )
        self.cnv3 = torch.nn.Conv1d( 10, 20, kernel_size=3 )
        self.cnv4 = torch.nn.Conv1d( 20, 30, kernel_size=3 )
        self.cnv5 = torch.nn.Conv1d( 30, 40, kernel_size=3 )
        self.bn1  = torch.nn.BatchNorm1d(  5 )
        self.bn2  = torch.nn.BatchNorm1d( 10 )
        self.bn3  = torch.nn.BatchNorm1d( 20 )
        self.bn4  = torch.nn.BatchNorm1d( 30 )
        self.bn5  = torch.nn.BatchNorm1d( 40 )
        # Shared pooling and dropout, followed by the classifier head.
        self.pool = torch.nn.MaxPool1d( kernel_size=2 )
        self.drop = torch.nn.Dropout( p=0.25 )
        self.fc1  = torch.nn.Linear( 480, 200 )
        self.fc2  = torch.nn.Linear( 200, 100 )
        self.fc3  = torch.nn.Linear( 100,  50 )
        self.fc4  = torch.nn.Linear(  50,  10 )
        self.fc5  = torch.nn.Linear(  10,   3 )

    def forward(self, x):
        # Treat each 300-dimensional embedding as a one-channel 1D signal.
        x = x.reshape(1, 1, 300)

        # Four conv -> batch-norm -> ReLU -> max-pool blocks.
        x = self.cnv1(x)
        x = self.bn1(x)
        x = torch.nn.functional.relu( x )
        x = self.pool(x)

        x = self.cnv2(x)
        x = self.bn2(x)
        x = torch.nn.functional.relu( x )
        x = self.pool(x)

        x = self.cnv3(x)
        x = self.bn3(x)
        x = torch.nn.functional.relu( x )
        x = self.pool(x)

        x = self.cnv4(x)
        x = self.bn4(x)
        x = torch.nn.functional.relu( x )
        x = self.pool(x)

        x = self.drop(x)

        # Flatten to (batch, 480) and classify into the three categories.
        x = x.reshape(x.shape[0], -1)
        x = torch.nn.functional.relu( self.fc1(x) )
        x = torch.nn.functional.relu( self.fc2(x) )
        x = torch.nn.functional.relu( self.fc3(x) )
        x = torch.nn.functional.gelu( self.fc4(x) )
        x = self.fc5(x)
        return x
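Since the convolutions use kernel_size=3 with no padding and every max-pool halves the sequence length (with floor), a 300-long input shrinks to 298, 149, 147, 73, 71, 35, 33 and finally 16 positions over the four blocks used in forward; 16 positions times 30 channels gives the 480 features expected by fc1. A quick check of that arithmetic:

# Sanity check of the flattened size feeding fc1, using only the layer
# definitions above (kernel_size=3, no padding, MaxPool1d(2)).
length, channels = 300, 1
for out_channels in (5, 10, 20, 30):
    length = length - 2    # Conv1d with kernel_size=3, no padding
    length = length // 2   # MaxPool1d with kernel_size=2
    channels = out_channels
print(channels * length)   # 30 * 16 = 480, matching fc1's in_features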
model = Network()
print(model)
criterion = torch.nn.CrossEntropyLoss(
weight=project.class_weights(y, as_type='tensor')
)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=.001, eps=1e-08, weight_decay=.02
)
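project.class_weights presumably compensates for the class imbalance reported above (1376 / 1572 / 1052 training samples) by weighting the loss. A common choice, sketched here as an assumption rather than the actual helper, is inverse-frequency weighting:

import numpy as np
import torch

def class_weights(y, as_type='tensor'):
    # Hypothetical inverse-frequency weights; the real project.class_weights
    # may normalise differently.
    counts = np.bincount(np.asarray(y, dtype=int))
    weights = counts.sum() / (len(counts) * counts)
    return torch.tensor(weights, dtype=torch.float32) if as_type == 'tensor' else weights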
performance = project.train_nn( model, x, y, criterion, optimizer, epochs=25, li=800 )
results = project.test_nn( model, z, t )
t, p, (accuracy, precision, recall), rl = results
project.plot_train(performance)
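project.train_nn and project.test_nn are project helpers; a rough sketch of the training side, assuming one sample per forward pass (matching the reshape(1, 1, 300) above) and a running loss printed every li steps, might look like this rather than being the actual implementation:

import torch

def train_nn(model, x, y, criterion, optimizer, epochs=25, li=800):
    # Hypothetical per-sample training loop; project.train_nn may shuffle,
    # batch, or report its loss differently.
    history = []
    for epoch in range(epochs):
        running = 0.0
        for i, (vector, label) in enumerate(zip(x, y), start=1):
            optimizer.zero_grad()
            output = model(torch.tensor(vector, dtype=torch.float32))
            loss = criterion(output, torch.tensor([label]))
            loss.backward()
            optimizer.step()
            running += loss.item()
            if i % li == 0:
                print(f'loss {i}: {running / i:.3f}')
        history.append(running / len(x))
    return history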
Network(
  (cnv1): Conv1d(1, 5, kernel_size=(3,), stride=(1,))
  (cnv2): Conv1d(5, 10, kernel_size=(3,), stride=(1,))
  (cnv3): Conv1d(10, 20, kernel_size=(3,), stride=(1,))
  (cnv4): Conv1d(20, 30, kernel_size=(3,), stride=(1,))
  (cnv5): Conv1d(30, 40, kernel_size=(3,), stride=(1,))
  (bn1): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn5): BatchNorm1d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (drop): Dropout(p=0.25, inplace=False)
  (fc1): Linear(in_features=480, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=50, bias=True)
  (fc4): Linear(in_features=50, out_features=10, bias=True)
  (fc5): Linear(in_features=10, out_features=3, bias=True)
)
AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0.02
)
CrossEntropyLoss()

TRAINING
loss 800: 1.041 loss 1600: 0.946 loss 2400: 0.911 loss 3200: 0.849 loss 4000: 0.840
loss 800: 0.793 loss 1600: 0.759 loss 2400: 0.788 loss 3200: 0.757 loss 4000: 0.750
loss 800: 0.712 loss 1600: 0.698 loss 2400: 0.753 loss 3200: 0.703 loss 4000: 0.721
loss 800: 0.695 loss 1600: 0.664 loss 2400: 0.739 loss 3200: 0.681 loss 4000: 0.673
loss 800: 0.689 loss 1600: 0.647 loss 2400: 0.732 loss 3200: 0.690 loss 4000: 0.690
loss 800: 0.671 loss 1600: 0.623 loss 2400: 0.703 loss 3200: 0.673 loss 4000: 0.674
loss 800: 0.648 loss 1600: 0.622 loss 2400: 0.724 loss 3200: 0.674 loss 4000: 0.672
loss 800: 0.647 loss 1600: 0.624 loss 2400: 0.720 loss 3200: 0.659 loss 4000: 0.682
loss 800: 0.636 loss 1600: 0.602 loss 2400: 0.694 loss 3200: 0.646 loss 4000: 0.660
loss 800: 0.626 loss 1600: 0.617 loss 2400: 0.713 loss 3200: 0.644 loss 4000: 0.639
loss 800: 0.635 loss 1600: 0.609 loss 2400: 0.676 loss 3200: 0.650 loss 4000: 0.651
loss 800: 0.631 loss 1600: 0.605 loss 2400: 0.691 loss 3200: 0.634 loss 4000: 0.633
loss 800: 0.620 loss 1600: 0.598 loss 2400: 0.673 loss 3200: 0.644 loss 4000: 0.641
loss 800: 0.601 loss 1600: 0.594 loss 2400: 0.664 loss 3200: 0.637 loss 4000: 0.629
loss 800: 0.601 loss 1600: 0.587 loss 2400: 0.694 loss 3200: 0.637 loss 4000: 0.629
loss 800: 0.610 loss 1600: 0.578 loss 2400: 0.689 loss 3200: 0.624 loss 4000: 0.622
loss 800: 0.612 loss 1600: 0.572 loss 2400: 0.664 loss 3200: 0.645 loss 4000: 0.609
loss 800: 0.598 loss 1600: 0.573 loss 2400: 0.671 loss 3200: 0.618 loss 4000: 0.621
loss 800: 0.604 loss 1600: 0.567 loss 2400: 0.694 loss 3200: 0.629 loss 4000: 0.623
loss 800: 0.598 loss 1600: 0.588 loss 2400: 0.672 loss 3200: 0.620 loss 4000: 0.627
loss 800: 0.578 loss 1600: 0.582 loss 2400: 0.656 loss 3200: 0.616 loss 4000: 0.621
loss 800: 0.608 loss 1600: 0.574 loss 2400: 0.667 loss 3200: 0.613 loss 4000: 0.629
loss 800: 0.599 loss 1600: 0.571 loss 2400: 0.665 loss 3200: 0.619 loss 4000: 0.609
loss 800: 0.582 loss 1600: 0.573 loss 2400: 0.639 loss 3200: 0.603 loss 4000: 0.613
loss 800: 0.569 loss 1600: 0.565 loss 2400: 0.646 loss 3200: 0.612 loss 4000: 0.616

TESTING
accuracy  0.6985915492957746
precision 0.6974371409469863
recall    0.6829780053622646
Fine-tuning scoreboard (CNN)
accuracy    precision   recall      cm_d         cl_s                   fcl_s            es  o_ps     afs
.72676056   .72091802   .71870643   187 208 121  5.3 10.3 10.3 5.3      200 100 50 10 3   5  lr=.001  gelu+relu
.72394366   .71718998   .71655630   192 202 120  5.3 10.3 10.3 5.3      200 100 50 10 3  10  lr=.001  gelu+relu
.71126760   .70832584   .69628140   176 222 107  5.3 10.3 10.3 5.3      200 100 50 10 3   5  lr=.001  gelu+gelu
.73239436   .73257180   .72487365   169 222 129  5.3 10.3 10.3 5.3      200 100 50 10 3  10  lr=.001  gelu+gelu
.71549295   .70872243   .70812110   188 201 119  5.3 10.3 10.3 5.3      200 100 50 10 3  15  lr=.001  gelu+gelu
.71267605   .71722856   .69430446   179 227 100  5.2 10.2 10.2 5.2      100 50 3         10  lr=.001  tanh+gelu
.72676056   .72785771   .71426128   169 229 118  5.2 10.2 10.2 5.2      200 100 50 10 3  10  lr=.001  tanh+gelu
.70281690   .70513830   .70080722   158 206 135  5.3 10.3 10.3 5.3      200 100 50 10 3   5  lr=.001  tanh+gelu
.73098591   .72889928   .71944873   182 220 117  5.3 10.3 10.3 5.3      200 100 50 10 3  10  lr=.001  tanh+gelu
.73098591   .72959104   .71865062   177 225 117  5.3 10.3 10.3 5.3      200 100 50 10 3   5  lr=.001  tanh+relu
.73098591   .73548818   .71623542   175 231 113  5.3 10.3 10.3 5.3      200 100 50 10 3  10  lr=.001  tanh+relu
.69436619   .70395044   .69342000   154 203 136  5.1 10.2 10.2 5.3      200 100 50 10 3   5  lr=.001  tanh+gelu
.71971830   .72140926   .71615078   163 213 135  5.1 10.2 10.2 5.3      200 100 50 10 3  10  lr=.001  tanh+gelu
.69295774   .68948801   .68216181   160 217 115  10.3 10.3 10.3 10.3    200 100 50 10 3   5  lr=.001  tanh+relu
.70985915   .70660968   .69600277   172 222 110  10.2 10.2 10.3 10.3    200 100 50 10 3   5  lr=.001  tanh+relu
.70140845   .70538124   .69153050   160 219 119  5.3 10.3 10.3 10.3     200 100 50 10 3   5  lr=.001  tanh+relu
.73943661   .73941735   .73797512   177 208 140  5.3 10.3 10.3 10.3     200 100 50 10 3  10  lr=.001  tanh+relu
.72394366   .71885208   .71423135   184 212 118  5.3 10.3 10.3 10.3     200 100 50 10 3  15  lr=.001  tanh+relu
.72253521   .72036646   .71178739   173 221 119  5.3 10.3 10.3 5.3 5.1  200 100 50 10 3   5  lr=.001  tanh+relu
.71971830   .71461776   .70745105   184 215 112  5.3 10.3 10.3 5.3 5.1  200 100 50 10 3  10  lr=.001  tanh+relu