#!/usr/bin/env python3

# restricted Boltzmann machine; source:
# https://blog.paperspace.com/beginners-guide-to-boltzmann-machines-pytorch/

# torchvision datasets
# https://pytorch.org/vision/stable/datasets.html

import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt

# loading the MNIST dataset
# 'train_loader' is an instance of a DataLoader
# 'test_loader' is its counterpart for testing (currently not used)

batch_size = 64                           # samples per mini-batch
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(                       # which dataset to load
        './data',                         # store in this local directory
        train=True,
        download=True,                    # download if not already present
        transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=batch_size)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=batch_size)
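
# Optional sanity check (a sketch, not part of the original tutorial): peek
# at one mini-batch to confirm the expected shapes. Each batch is
# [batch_size, 1, 28, 28]; the training loop below flattens it to 784 inputs.
_images, _labels = next(iter(train_loader))
print("batch shape:", _images.shape)      # torch.Size([64, 1, 28, 28])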

#
# defining the restricted Boltzmann machine
# 'vis' : visible units
# 'hin' : hidden units
#
class RBM(nn.Module):
    def __init__(self, n_vis=784, n_hin=500, k=5):
        super(RBM, self).__init__()
        self.W = nn.Parameter(torch.randn(n_hin, n_vis)*1e-2)
        self.v_bias = nn.Parameter(torch.zeros(n_vis))
        self.h_bias = nn.Parameter(torch.zeros(n_hin))
        self.k = k                        # number of Gibbs sampling steps (CD-k)

    def sample_from_p(self, p):           # Bernoulli sample: p -> 0/1
        # sign(p - u) is +1 with probability p and -1 otherwise; relu maps
        # -1 to 0, so this is equivalent to torch.bernoulli(p)
        return F.relu(torch.sign(p - torch.rand(p.size())))

    def v_to_h(self, v):
        p_h = torch.sigmoid(F.linear(v, self.W, self.h_bias))  # hidden probabilities
        sample_h = self.sample_from_p(p_h)                     # 0/1 sample
        return p_h, sample_h

    def h_to_v(self, h):                  # transpose W for the other direction
        p_v = torch.sigmoid(F.linear(h, self.W.t(), self.v_bias))
        sample_v = self.sample_from_p(p_v)
        return p_v, sample_v

    def forward(self, v):
        pre_h1, h1 = self.v_to_h(v)
        h_ = h1
        for _ in range(self.k):           # k steps of alternating Gibbs sampling
            pre_v_, v_ = self.h_to_v(h_)  # with 0/1 samples
            pre_h_, h_ = self.v_to_h(v_)
        return v, v_                      # return input and 0/1 reconstruction

    def free_energy(self, v):
        """Free energy with the binary hidden units summed out analytically:
           F(v) = -v.v_bias - sum_j log(1 + exp(W_j.v + h_bias_j))
        'v' is visible activity, either data or a reconstruction; the
        training loss below is the mean free-energy gap between the two.
        """
        vbias_term  = v.mv(self.v_bias)
        wx_b        = F.linear(v, self.W, self.h_bias)
        hidden_term = wx_b.exp().add(1).log().sum(1)
        return (-hidden_term - vbias_term).mean()
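
# Optional smoke test (a sketch; the tiny layer sizes are made up for
# illustration): one forward pass through an untrained RBM should return
# the input and a 0/1 reconstruction of the same shape.
_toy = RBM(n_vis=6, n_hin=3, k=1)
_v_in, _v_out = _toy(torch.rand(2, 6).bernoulli())
assert _v_in.shape == _v_out.shape == (2, 6)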

#
# define model, register optimizer
# SGD: stochastic gradient descent
#
model = RBM(k=1)
train_op = optim.SGD(model.parameters(), lr=0.1)

#
# training model
#
for epoch in range(2):
    loss_ = []
    for data, target in train_loader:
        data = data.view(-1, 784)         # flatten 28x28 images to 784 inputs
        sample_data = data.bernoulli()    # binarize: pixel intensity = P(on)

        v, v1 = model(sample_data)
        loss = model.free_energy(v) - model.free_energy(v1)  # CD surrogate loss
        loss_.append(loss.item())
        train_op.zero_grad()
        loss.backward()
        train_op.step()

    print("Training loss for {} epoch: {}".format(epoch, np.mean(loss_)))

#
#
# storing and displaying images
#
def show_and_save(file_id, img, fig, position):
    npimg = np.transpose(img.numpy(), (1, 2, 0))  # CHW -> HWC for matplotlib
    fileName = "RBM_out_" + file_id + ".png"
    fig.add_subplot(1, 2, position)
    plt.title(file_id)
    plt.axis('off')
    plt.imshow(npimg)
    plt.imsave(fileName, npimg)

#
# visualising training outputs
#
fig = plt.figure(figsize=(12, 6))
# the last training batch holds 60000 % 64 = 32 images, hence view(32, ...)
show_and_save("real",      make_grid( v.view(32, 1, 28, 28).data), fig, 1)
show_and_save("generated", make_grid(v1.view(32, 1, 28, 28).data), fig, 2)
plt.show()
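
# 'save_image' is imported above but never used; as a sketch, it could write
# the generated samples to disk directly (the filename is illustrative):
save_image(v1.view(32, 1, 28, 28).data, "RBM_generated_grid.png")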
