We previously created an R notebook version of the first portion of the movielens Python notebook from the Fastai Deep Learning for Coders (Part 1) course, where high-level fastai functions were used to build and fit the model. This notebook attempts an R version of the second portion of the movielens Python notebook, where Jeremy builds the collaborative filtering model from scratch.

This content is covered in the videos of lecture 5 and lecture 6. It will help to watch the lectures before going through this notebook, since the concepts behind the model and the approach are discussed there; this notebook is simply an attempt to replicate the course material in R.

Initial Setup

# import R libraries
library(reticulate)
library(ggplot2)
library(dplyr)

Attaching package: ‘dplyr’

The following objects are masked from ‘package:stats’:

    filter, lag

The following objects are masked from ‘package:base’:

    intersect, setdiff, setequal, union

Notes about the Python setup and the machine used are covered in the earlier movielens R notebook.

use_python("/home/paperspace/anaconda3/envs/fastai/bin/python", required = TRUE)
use_condaenv("fastai")
py_config()
python:         /home/paperspace/anaconda3/envs/fastai/bin/python
libpython:      /home/paperspace/anaconda3/envs/fastai/lib/libpython3.6m.so
pythonhome:     /home/paperspace/anaconda3/envs/fastai:/home/paperspace/anaconda3/envs/fastai
version:        3.6.4 |Anaconda, Inc.| (default, Dec 21 2017, 21:42:08)  [GCC 7.2.0]
numpy:          /home/paperspace/anaconda3/envs/fastai/lib/python3.6/site-packages/numpy
numpy_version:  1.13.3

NOTE: Python version was forced by use_python function
main = import_main()
bi = import_builtins()
# get relevant python imports
fstai_learner = import_from_path("fastai.learner", "../../fastai")
fstai_coldata = import_from_path("fastai.column_data", "../../fastai")
py_run_string("
from fastai.learner import *
from fastai.column_data import *
              ")

Get Data

The ratings dataset has the ratings for different users and movies. The movies dataset has the movie title information.

datapath = "../../data/ml-latest-small/"
ratings = read.csv(paste0(datapath, "ratings.csv"), stringsAsFactors = FALSE)
head(ratings)
movies = read.csv(paste0(datapath, "movies.csv"), stringsAsFactors = FALSE)
head(movies)

Working with PyTorch Tensors

Amat = matrix(c(1.0, 2.0, 3.0, 4.0), byrow = TRUE, ncol = 2)
Bmat = matrix(c(2.0, 2.0, 10.0, 10.0), byrow = TRUE, ncol = 2)
a = py$T(Amat)
a

 1  2
 3  4
[torch.cuda.FloatTensor of size 2x2 (GPU 0)]
b = py$T(Bmat)
b

  2   2
 10  10
[torch.cuda.FloatTensor of size 2x2 (GPU 0)]

In Python, a*b works fine, but it won't work when called from R, since R doesn't know how to apply the * operator to Torch tensors. So here I have used the PyTorch mul method for element-wise multiplication.

a$mul(b)

  2   4
 30  40
[torch.cuda.FloatTensor of size 2x2 (GPU 0)]
a$mul(b)$sum(1L) # sum over dimension 1 (across columns); 1L instead of 1 ensures we pass an integer, not a double

  6
 70
[torch.cuda.FloatTensor of size 2 (GPU 0)]
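The element-wise product followed by sum(1) is simply a row-by-row dot product. A minimal pure-Python sketch of the same arithmetic, with plain lists standing in for tensors (the helper names mul and sum_dim1 are illustrative, not PyTorch functions):

```python
def mul(a, b):
    """Element-wise product of two equally shaped 2-D lists (like a*b or a.mul(b))."""
    return [[x * y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def sum_dim1(m):
    """Sum each row, i.e. reduce over dimension 1 (like .sum(1))."""
    return [sum(row) for row in m]


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[2.0, 2.0], [10.0, 10.0]]

print(mul(A, B))            # [[2.0, 4.0], [30.0, 40.0]]
print(sum_dim1(mul(A, B)))  # [6.0, 70.0]
```

Each entry of the result is the dot product of row i of A with row i of B, which is exactly what the DotProduct model below computes for a batch of user and movie vectors.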

In PyTorch, we need to define a model as a Python class that inherits from nn.Module. The class defines a forward method, which gives the recipe for computing the prediction from the inputs. For example, if the prediction r is the dot product of input vectors u and m, the model would be defined by the following class.

py_run_string("
class DotProduct(nn.Module):
    def forward(self, u, m): return (u*m).sum(1)              
              ")

A model can then be created as an instance of the class; calling its forward method with the inputs gives the prediction.

model = main$DotProduct()
model$forward(a, b)

  6
 70
[torch.cuda.FloatTensor of size 2 (GPU 0)]

Data Preparation for Use in PyTorch Model

The ratings dataset has userId and movieId fields. Since PyTorch embedding layers are indexed by position, we create a sequential, zero-based index for both.

u_unique = unique(ratings$userId)
user2idx = as.integer(seq(0, length(u_unique) - 1))
names(user2idx) = u_unique
user2idx[1:10]
 1  2  3  4  5  6  7  8  9 10 
 0  1  2  3  4  5  6  7  8  9 
m_unique = unique(ratings$movieId)
movie2idx = as.integer(seq(0, length(m_unique) - 1))
names(movie2idx) = m_unique
movie2idx[1:10]
  31 1029 1061 1129 1172 1263 1287 1293 1339 1343 
   0    1    2    3    4    5    6    7    8    9 
ratings$userIdx = user2idx[as.character(ratings$userId)]
ratings$movieIdx = movie2idx[as.character(ratings$movieId)]
n_users = length(u_unique)
n_movies = length(m_unique)
n_users; n_movies
[1] 671
[1] 9066
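The remapping matters because an nn.Embedding with n rows can only be indexed by the integers 0 to n-1, while the raw movieId values are sparse (31, 1029, 1061, ...). A minimal Python sketch of the same idea, using made-up ids and a hypothetical build_index helper; like R's unique(), it assigns indices in order of first appearance:

```python
def build_index(ids):
    """Map each distinct id to a contiguous 0-based index, in order of first appearance."""
    mapping = {}
    for i in ids:
        if i not in mapping:
            mapping[i] = len(mapping)
    return mapping


# Hypothetical movie ids from a few rating rows
movie_ids = [31, 1029, 31, 1061, 1029, 1129]
movie2idx = build_index(movie_ids)
movie_idx = [movie2idx[m] for m in movie_ids]

print(movie2idx)  # {31: 0, 1029: 1, 1061: 2, 1129: 3}
print(movie_idx)  # [0, 1, 0, 2, 1, 3]
```

The number of distinct ids (len(movie2idx)) then gives the number of embedding rows, which is the role n_users and n_movies play above.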

Model 1

Each user \(i\) is represented by an embedding vector \(u_i\) consisting of n_factors values. Similarly, a movie \(j\) is represented by an embedding vector \(m_j\) consisting of n_factors values. The model of the rating \(r_{ij}\) given by user \(i\) to movie \(j\) is: \[ r_{ij} = u_i^T m_j \] The class below defines the model. Its constructor takes the remaining inputs (the numbers of users, movies and factors) and initializes the model parameters.

py_run_string("
class EmbeddingDot(nn.Module):
    def __init__(self, n_users, n_movies, n_factors):
        super().__init__()
        self.u = nn.Embedding(n_users, n_factors)
        self.m = nn.Embedding(n_movies, n_factors)
        self.u.weight.data.uniform_(0,0.05)
        self.m.weight.data.uniform_(0,0.05)
        
    def forward(self, cats, conts):
        users,movies = cats[:,0],cats[:,1]
        u,m = self.u(users),self.m(movies)
        return (u*m).sum(1)              
              ")
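An embedding lookup is just row indexing into a weight matrix, so EmbeddingDot.forward amounts to "fetch row i of the user table, fetch row j of the movie table, take their dot product". A minimal pure-Python sketch of that logic for illustration; make_embedding and predict are hypothetical helpers, and the uniform initialization between 0 and 0.05 mirrors the uniform_(0,0.05) call in the class:

```python
import random


def make_embedding(n_rows, n_factors, low=0.0, high=0.05, seed=0):
    """An embedding is a weight matrix with one row per id, initialized uniformly."""
    rng = random.Random(seed)
    return [[rng.uniform(low, high) for _ in range(n_factors)] for _ in range(n_rows)]


def predict(u_emb, m_emb, cats):
    """cats is a list of (user_idx, movie_idx) pairs, like the cats tensor in forward().
    The prediction for each pair is the dot product u_i . m_j."""
    preds = []
    for user, movie in cats:
        u, m = u_emb[user], m_emb[movie]  # embedding lookup = row indexing
        preds.append(sum(a * b for a, b in zip(u, m)))
    return preds


# Small hand-written tables instead of random ones, so the result is easy to check
u_emb = [[1.0, 0.0], [0.5, 2.0]]               # 2 users, 2 factors
m_emb = [[3.0, 4.0], [1.0, 1.0], [0.0, 2.0]]   # 3 movies, 2 factors
print(predict(u_emb, m_emb, [(0, 0), (1, 2)]))  # [3.0, 4.0]
```

In the real model the tables are trainable parameters, so SGD adjusts the rows until these dot products approximate the observed ratings.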

The x data frame contains just the sequential user and movie indices, and y is a NumPy array of the ratings with type float32.

x = ratings[, c("userIdx", "movieIdx")]
y = np_array(ratings$rating)$astype(py$np$float32)

The indices of the validation rows are selected, and n_factors is set to 50.

val_idxs = py$get_cv_idxs(nrow(ratings))
val_idxs = as.integer(val_idxs)
n_factors = 50L
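get_cv_idxs returns the row indices held out for validation. Assuming it behaves like a fixed-seed random 20% holdout (an assumption about fastai's defaults, not verified here), the idea can be sketched with a hypothetical random_val_idxs helper:

```python
import random


def random_val_idxs(n, val_pct=0.2, seed=42):
    """Hypothetical helper: return a random val_pct fraction of the row indices 0..n-1.
    A fixed seed makes the split reproducible between runs."""
    rng = random.Random(seed)
    idxs = list(range(n))
    rng.shuffle(idxs)
    return sorted(idxs[: int(val_pct * n)])


val = random_val_idxs(100)
print(len(val))  # 20
```

The returned indices are zero-based, which is why they can be passed straight back to Python after the as.integer conversion in R.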

The model data object, which wraps the training and validation data loaders, is created with a batch size of 64.

data = py$ColumnarModelData$from_data_frame(datapath, val_idxs, r_to_py(x), y, c("userIdx", "movieIdx"), 64L)

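The model is created from the EmbeddingDot class and moved to the GPU with cuda(). The optimizer opt is then defined on the model's parameters: SGD with a learning rate of 0.1, momentum of 0.9 and weight decay wd.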

wd=1e-5
model = py$EmbeddingDot(n_users, n_movies, n_factors)$cuda()
opt = py$optim$SGD(model$parameters(), 1e-1, weight_decay=wd, momentum=0.9)
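To make the optimizer settings concrete, here is a scalar sketch of one SGD step with momentum and weight decay, written to follow PyTorch's SGD update with no dampening and no Nesterov momentum (sgd_momentum_step is a hypothetical helper, not a PyTorch function):

```python
def sgd_momentum_step(params, grads, bufs, lr=0.1, momentum=0.9, weight_decay=1e-5):
    """One SGD update for lists of scalar parameters:
        d   = grad + weight_decay * p   (weight decay acts as an L2 penalty)
        buf = momentum * buf + d        (running velocity)
        p   = p - lr * buf
    Returns the updated (params, bufs)."""
    new_params, new_bufs = [], []
    for p, g, buf in zip(params, grads, bufs):
        d = g + weight_decay * p
        buf = momentum * buf + d
        new_params.append(p - lr * buf)
        new_bufs.append(buf)
    return new_params, new_bufs


# Usage: minimise f(p) = p**2 (gradient 2p) starting from p = 1.0
params, bufs = [1.0], [0.0]
for step in range(3):
    grads = [2 * p for p in params]
    params, bufs = sgd_momentum_step(params, grads, bufs, lr=0.1, momentum=0.9, weight_decay=0.0)
    print(step, params)
```

The momentum buffer accumulates past gradients, so consecutive steps in the same direction get larger; set_lrs later only changes lr in this update.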

The model details are listed below

model
EmbeddingDot(
  (u): Embedding(671, 50)
  (m): Embedding(9066, 50)
)

The model is fit for 3 epochs. In a Jupyter notebook, a widget shows a nice progress display, but that output doesn't render properly within RStudio and makes the generated HTML document too long. For now, I have turned off the output and am just storing the final validation loss metric (MSE loss in this case).

mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
[1] "MSE of validation set = 1.219"

The learning rate is lowered to 0.01 and the model is run for 3 more epochs.

py$set_lrs(opt, 0.01)
mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
[1] "MSE of validation set = 1.131"

An alternative way to check the validation MSE is to explicitly get the model's predictions for the validation data and compare them with the actual validation ratings.

yval_preds = py$predict(model, data$val_dl)
yval=data$val_y[,1]
mse_loss = mean((yval_preds - yval)**2)
mse_loss
[1] 1.131446
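The R expression mean((yval_preds - yval)**2) is the plain mean squared error, the quantity F.mse_loss averages over a batch. A minimal Python sketch of the same computation with made-up numbers (mse is a hypothetical helper):

```python
def mse(preds, targets):
    """Mean squared error between predictions and targets."""
    if len(preds) != len(targets) or not preds:
        raise ValueError("preds and targets must be non-empty and the same length")
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


# Squared errors are 0.25, 0.0 and 1.0, so the mean is 1.25 / 3 (about 0.4167)
print(mse([3.5, 4.0, 2.0], [4.0, 4.0, 3.0]))
```

Because the errors are squared, a validation MSE of about 1.13 corresponds to a typical error of roughly one rating point (its square root).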

Model 2

This is model 1 with an added bias term for users and movies. Each user \(i\) is represented by an embedding vector \(u_i\) consisting of n_factors values and a user bias \(ub_i\). Similarly, a movie \(j\) is represented by an embedding vector \(m_j\) consisting of n_factors values and a movie bias \(mb_j\). In addition, the output is constrained to lie between the minimum rating r_min and the maximum rating r_max using a sigmoid function. The model of the rating \(r_{ij}\) given by user \(i\) to movie \(j\) is: \[ r_{ij} = \frac{e^{u_i^T m_j + ub_i + mb_j}}{e^{u_i^T m_j + ub_i + mb_j} + 1}(r_{max} - r_{min}) + r_{min} \]
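The fraction e^z / (e^z + 1) is the sigmoid function, so the prediction is sigmoid(u_i . m_j + ub_i + mb_j) rescaled to the rating range, which is what EmbeddingDotBias.forward computes. A minimal Python sketch for a single user/movie pair; the helper names and the rating range 0.5 to 5.0 used in the example are illustrative assumptions:

```python
import math


def sigmoid(x):
    """Numerically stable logistic function; for x < 0 it uses e^x / (e^x + 1) directly."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def predict_with_bias(u_i, m_j, ub_i, mb_j, r_min, r_max):
    """r_ij = sigmoid(u_i . m_j + ub_i + mb_j) * (r_max - r_min) + r_min"""
    z = sum(a * b for a, b in zip(u_i, m_j)) + ub_i + mb_j
    return sigmoid(z) * (r_max - r_min) + r_min


# With all-zero embeddings and biases, sigmoid(0) = 0.5 gives the midpoint of the range
print(predict_with_bias([0.0, 0.0], [0.0, 0.0], 0.0, 0.0, 0.5, 5.0))  # 2.75
```

However large or small the linear score gets, the output stays inside (r_min, r_max), which keeps the model from wasting effort on predictions outside the valid rating scale.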

py_run_string("
def get_emb(ni,nf):
    e = nn.Embedding(ni, nf)
    e.weight.data.uniform_(-0.01,0.01)
    return e
              
class EmbeddingDotBias(nn.Module):
    def __init__(self, n_users, n_movies, n_factors, min_rating, max_rating):
        super().__init__()
        (self.u, self.m, self.ub, self.mb) = [get_emb(*o) for o in [
              (n_users, n_factors), (n_movies, n_factors), (n_users,1), (n_movies,1)
              ]]
        self.max_rating = max_rating
        self.min_rating = min_rating
              
    def forward(self, cats, conts):
        users,movies = cats[:,0],cats[:,1]
        um = (self.u(users)* self.m(movies)).sum(1)
        res = um + self.ub(users).squeeze() + self.mb(movies).squeeze()
        res = F.sigmoid(res) * (self.max_rating-self.min_rating) + self.min_rating
        return res              
              ")

The model and the optimizer opt are defined; the minimum and maximum ratings are taken from the data.

wd=2e-4
min_rating = min(ratings$rating)
max_rating = max(ratings$rating)
model = py$EmbeddingDotBias(n_users, n_movies, n_factors, min_rating, max_rating)$cuda()
opt = py$optim$SGD(model$parameters(), 1e-1, weight_decay=wd, momentum=0.9)

The model details are listed below

model
EmbeddingDotBias(
  (u): Embedding(671, 50)
  (m): Embedding(9066, 50)
  (ub): Embedding(671, 1)
  (mb): Embedding(9066, 1)
)
mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
[1] "MSE of validation set = 0.807"

Change the learning rate and refit

py$set_lrs(opt, 0.01)
mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
[1] "MSE of validation set = 0.801"

Model 3

This model is a deep learning model with the following layers:

  1. Input - Concatenation of user embedding vector and movie embedding vector (size = 2*n_factors)
  2. Dropout with dropout rate p1=0.05
  3. Linear fully connected layer with output size nh = 10
  4. ReLU activation
  5. Dropout with dropout rate p2=0.5
  6. Linear fully connected layer with output size 1.
  7. Apply a sigmoid function and scale the output to lie between min_rating - 0.5 and max_rating + 0.5
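The layers listed above can be sketched as a forward pass for a single user/movie pair in plain Python. This sketch assumes evaluation mode, where nn.Dropout is the identity (in training mode it zeroes each activation with probability p and scales the survivors by 1/(1-p)), so the dropout steps are skipped. The helpers and the tiny weights are hypothetical; as in nn.Linear, each weight row belongs to one output unit:

```python
import math


def linear(x, weight, bias):
    """Fully connected layer: one weight row and one bias per output unit."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(v):
    return [max(0.0, x) for x in v]


def embedding_net_forward(u, m, w1, b1, w2, b2, min_rating, max_rating):
    """Sketch of EmbeddingNet.forward in eval mode (dropout steps 2 and 5 skipped)."""
    x = u + m                        # step 1: concatenate lists -> length 2*n_factors
    h = relu(linear(x, w1, b1))      # steps 3-4: Linear(2*n_factors, nh) + ReLU
    z = linear(h, w2, b2)[0]         # step 6: Linear(nh, 1)
    s = 1.0 / (1.0 + math.exp(-z))   # step 7: sigmoid, then rescale
    return s * (max_rating - min_rating + 1) + min_rating - 0.5


# Tiny example: n_factors = 1, nh = 2
out = embedding_net_forward(
    u=[1.0], m=[2.0],
    w1=[[1.0, 0.0], [0.0, -1.0]], b1=[0.0, 0.0],
    w2=[[-1.0, 5.0]], b2=[1.0],
    min_rating=0.5, max_rating=5.0,
)
print(out)  # 2.75
```

Because of the "+1" and "-0.5" in the rescaling, the output range is (min_rating - 0.5, max_rating + 0.5), so the extreme ratings can be reached without the sigmoid saturating.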
py_run_string("
class EmbeddingNet(nn.Module):
    def __init__(self, n_users, n_movies, n_factors, min_rating, max_rating, nh=10, p1=0.05, p2=0.5):
        super().__init__()
        (self.u, self.m) = [get_emb(*o) for o in [
            (n_users, n_factors), (n_movies, n_factors)]]
        self.lin1 = nn.Linear(n_factors*2, nh)
        self.lin2 = nn.Linear(nh, 1)
        self.drop1 = nn.Dropout(p1)
        self.drop2 = nn.Dropout(p2)
        self.min_rating = min_rating
        self.max_rating = max_rating
        
    def forward(self, cats, conts):
        users,movies = cats[:,0],cats[:,1]
        x = self.drop1(torch.cat([self.u(users),self.m(movies)], dim=1))
        x = self.drop2(F.relu(self.lin1(x)))
        return F.sigmoid(self.lin2(x)) * (self.max_rating-self.min_rating+1) + self.min_rating-0.5
           ")

The model is created and an Adam optimizer (learning rate 1e-3, weight decay wd) is defined.

wd=1e-5
model = py$EmbeddingNet(n_users, n_movies, n_factors, min_rating, max_rating)$cuda()
opt = py$optim$Adam(model$parameters(), 1e-3, weight_decay=wd)
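Unlike the SGD runs above, this model uses Adam. A scalar sketch of the standard Adam update (Kingma and Ba), with weight decay added to the gradient as an L2 penalty; adam_step is a hypothetical helper, and torch.optim.Adam computes an equivalent step up to exactly where eps enters:

```python
import math


def adam_step(p, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-5):
    """One Adam update for a single scalar parameter p with gradient g.
    m and v are the running first and second moment estimates; t is the 1-based step.
    Returns the updated (p, m, v)."""
    g = g + weight_decay * p                 # L2-style weight decay
    m = beta1 * m + (1 - beta1) * g          # moving average of gradients
    v = beta2 * v + (1 - beta2) * g * g      # moving average of squared gradients
    m_hat = m / (1 - beta1 ** t)             # bias corrections for zero initialization
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


# On the first step the move is about lr in size, whatever the gradient's scale
for grad in (100.0, 0.01):
    p, _, _ = adam_step(1.0, grad, 0.0, 0.0, t=1, weight_decay=0.0)
    print(grad, p)
```

This per-parameter normalization is one reason Adam is usually run with a much smaller learning rate (1e-3 here) than SGD.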

The model details are listed below

model
EmbeddingNet(
  (u): Embedding(671, 50)
  (m): Embedding(9066, 50)
  (lin1): Linear(in_features=100, out_features=10)
  (lin2): Linear(in_features=10, out_features=1)
  (drop1): Dropout(p=0.05)
  (drop2): Dropout(p=0.5)
)
mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
[1] "MSE of validation set = 0.784"

The learning rate is increased to 0.01 and the model is refit. Note that the validation MSE gets worse (from 0.784 to 0.828).

py$set_lrs(opt, 0.01)
mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
[1] "MSE of validation set = 0.828"

Summary

This shows how a model can be developed from scratch from R, although the model itself still has to be defined as a Python class. I suspect that if somebody is developing a model from scratch, it might be better to do it entirely in Python and create wrapper functions that can then be used from R.

sessionInfo()
R version 3.4.3 (2017-11-30)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 16.04.3 LTS

Matrix products: default
BLAS: /usr/lib/libblas/libblas.so.3.6.0
LAPACK: /home/paperspace/anaconda3/envs/fastai/lib/libmkl_intel_lp64.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=en_US.UTF-8       
 [4] LC_COLLATE=en_US.UTF-8     LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                  LC_ADDRESS=C              
[10] LC_TELEPHONE=C             LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] dplyr_0.7.4    ggplot2_2.2.1  reticulate_1.6

loaded via a namespace (and not attached):
 [1] Rcpp_0.12.14     assertthat_0.2.0 R6_2.2.2         grid_3.4.3       plyr_1.8.4      
 [6] jsonlite_1.5     gtable_0.2.0     magrittr_1.5     scales_0.5.0     pillar_1.1.0    
[11] rlang_0.1.6      lazyeval_0.2.1   bindrcpp_0.2     tools_3.4.3      glue_1.2.0      
[16] munsell_0.4.3    yaml_2.1.18      compiler_3.4.3   pkgconfig_2.0.1  colorspace_1.3-2
[21] bindr_0.1.1      knitr_1.20       tibble_1.4.2    