Siamese networks¶
Python source code: siamese_mnist.py
import random
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import offsetbox
import deeppy as dp
# Fetch MNIST data (flattened images, converted to deeppy's dtypes)
dataset = dp.dataset.MNIST()
x_train, y_train, x_test, y_test = dataset.data(dp_dtypes=True, flat=True)

# Normalize pixel intensities: the scaler is fitted on the training images
# and the same fitted scaler is then applied to the test images.
scaler = dp.StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)
# Generate image pairs
# Each pair is two distinct, randomly drawn training images; the pair label
# is 1 when both images carry the same class label and 0 otherwise.
n_pairs = 100000
x1 = np.empty((n_pairs, 28*28), dtype=dp.float_)
x2 = np.empty_like(x1, dtype=dp.float_)
y = np.empty(n_pairs, dtype=dp.int_)
n_train = x_train.shape[0]
filled = 0
while filled < n_pairs:
    first = random.randrange(n_train)
    second = random.randrange(n_train)
    # Drawing the same image twice does not count as a pair: draw again.
    if first != second:
        x1[filled] = x_train[first]
        x2[filled] = x_train[second]
        y[filled] = 1 if y_train[first] == y_train[second] else 0
        filled += 1

# Prepare network inputs
train_input = dp.SupervisedSiameseInput(x1, x2, y, batch_size=128)
# Setup network
# Layer stack: two 1024-unit fully connected layers, each followed by a ReLU,
# then a 2-unit fully connected output layer. Weight decay is set on the two
# hidden layers only; the output layer's weights are created without it.
w_gain = 1.5
w_decay = 1e-4

siamese_layers = []
for _ in range(2):
    siamese_layers.append(
        dp.FullyConnected(
            n_out=1024,
            weights=dp.Parameter(dp.AutoFiller(w_gain), weight_decay=w_decay),
        )
    )
    siamese_layers.append(dp.Activation('relu'))
siamese_layers.append(
    dp.FullyConnected(
        n_out=2,
        weights=dp.Parameter(dp.AutoFiller(w_gain)),
    )
)

net = dp.SiameseNetwork(
    siamese_layers=siamese_layers,
    loss=dp.ContrastiveLoss(margin=1.0),
)
# Train network
# Stochastic gradient descent for at most 15 epochs, using the RMSProp
# learning rule with a learning rate of 0.01.
rmsprop = dp.RMSProp(learn_rate=0.01)
trainer = dp.StochasticGradientDescent(max_epochs=15, learn_rule=rmsprop)
trainer.train(net, train_input)
# Plot 2D embedding
# The test images are pushed through the trained network, the resulting
# features are rescaled per axis, and a thinned-out subset of the digit
# images is drawn at their embedded positions.
test_input = dp.Input(x_test)
# Restore the image shape so each test sample can be drawn as a thumbnail.
x_test = np.reshape(x_test, (-1,) + dataset.img_shape)
feat = net.features(test_input)

# Rescale every embedding axis to start at 0 and span at most 1.
feat -= np.min(feat, 0)
feat_range = np.max(feat, 0)
# A constant axis has zero range; dividing by it would fill the coordinates
# with NaN and make every distance test below evaluate to False. Divide by 1
# instead, which leaves that axis at 0.
feat_range[feat_range == 0] = 1
feat /= feat_range

plt.figure()
ax = plt.subplot(111)
# Seed the list of already-drawn positions with the corner (1, 1).
shown_images = np.array([[1., 1.]])
for point, image in zip(feat, x_test):
    # Squared distance from this point to every position drawn so far.
    dist = np.sum((point - shown_images)**2, 1)
    if np.min(dist) < 6e-4:
        # don't show points that are too close
        continue
    shown_images = np.r_[shown_images, [point]]
    imagebox = offsetbox.AnnotationBbox(
        offsetbox.OffsetImage(image, zoom=0.6, cmap=plt.cm.gray_r),
        xy=point, frameon=False
    )
    ax.add_artist(imagebox)
# Hide the axis ticks.
plt.xticks([])
plt.yticks([])
plt.title('Embedding from the last layer of the network')
Total running time of the example: 1 minute 41.4 seconds