---
title: "Fastai Collaborative Filtering (from Scratch) with R and Reticulate"
output: html_notebook
editor_options: 
  chunk_output_type: inline
---

We had created an [R notebook](https://notesofdabbler.github.io/fastai_dl1_withR/movieLens.nb.html) version of the first portion of the movielens [python notebook](https://github.com/fastai/fastai/blob/master/courses/dl1/lesson5-movielens.ipynb) from the [Fastai Deep Learning for Coders (Part 1)](http://course.fast.ai/) course, where high-level fastai functions were used to build and fit the model. This notebook creates the R version of the second portion of the movielens python notebook, where Jeremy builds the collaborative filtering model from scratch.

This content is covered in the videos of [lecture 5](http://course.fast.ai/lessons/lesson5.html) and [lecture 6](http://course.fast.ai/lessons/lesson6.html). It will be helpful to watch the lectures before going through this notebook, since the concepts behind the model and the approach are discussed there; this notebook is just an attempt to replicate the course material using R.

## Initial Setup

```{r}
# import R libraries
library(reticulate)
library(ggplot2)
library(dplyr)

```

Notes about the Python setup and the machine used are covered in this [R notebook](https://notesofdabbler.github.io/fastai_dl1_withR/movieLens.nb.html).

```{r}

use_python("/home/paperspace/anaconda3/envs/fastai/bin/python", required = TRUE)
use_condaenv("fastai")
py_config()

main = import_main()
bi = import_builtins()
```

```{r}
# get relevant python imports
fstai_learner = import_from_path("fastai.learner", "../../fastai")
fstai_coldata = import_from_path("fastai.column_data", "../../fastai")
```

```{r}
py_run_string("
from fastai.learner import *
from fastai.column_data import *
              ")
```

## Get Data

The ratings dataset has the ratings for different users and movies. The movies dataset has the movie title information.

```{r}
datapath = "../../data/ml-latest-small/"

ratings = read.csv(paste0(datapath, "ratings.csv"), stringsAsFactors = FALSE)
head(ratings)

movies = read.csv(paste0(datapath, "movies.csv"), stringsAsFactors = FALSE)
head(movies)
```

## Working with PyTorch Tensors
```{r}
Amat = matrix(c(1.0, 2.0, 3.0, 4.0), byrow = TRUE, ncol = 2)
Bmat = matrix(c(2.0, 2.0, 10.0, 10.0), byrow = TRUE, ncol = 2)

a = py$T(Amat)
a

b = py$T(Bmat)
b
```
In Python, `a*b` works fine, but it won't work when called from R since R doesn't know how to apply the `*` operator to Torch tensors. So here I have used the PyTorch `mul` method for multiplication.
```{r}
a$mul(b)
a$mul(b)$sum(1L) # note the use of 1L instead of 1 to ensure that we are passing integer
```
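The `mul` followed by `sum(1)` pattern is just a row-wise dot product. A plain-Python sketch (no Torch, purely for illustration) of what `a$mul(b)$sum(1L)` computes on the matrices above:

```python
# Row-wise dot products: elementwise multiply, then sum across each row.
# This mirrors a$mul(b)$sum(1L) from the chunk above, without Torch.
A = [[1.0, 2.0], [3.0, 4.0]]
B = [[2.0, 2.0], [10.0, 10.0]]

elementwise = [[x * y for x, y in zip(ra, rb)] for ra, rb in zip(A, B)]
row_sums = [sum(row) for row in elementwise]

print(elementwise)  # [[2.0, 4.0], [30.0, 40.0]]
print(row_sums)     # [6.0, 70.0]
```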

In PyTorch, a model is defined as a Python class that inherits from `nn.Module`. It has a specific method `forward`, which gives the recipe for computing the prediction from the inputs. For example, if the prediction `r` is the dot product of input vectors `u` and `m`, it would be defined by the following class.
```{r}
py_run_string("
class DotProduct(nn.Module):
    def forward(self, u, m): return (u*m).sum(1)              
              ")
```
A model can then be defined as an instance of the class, and calling the `forward` method with the inputs will give the prediction.
```{r}
model = main$DotProduct()
model$forward(a, b)
```

## Data Preparation for Use in PyTorch Model

The ratings dataset has `userId` and `movieId` fields. But for passing to PyTorch, we create a sequential zero-based index for both.

```{r}
u_unique = unique(ratings$userId)
user2idx = as.integer(seq(0, length(u_unique) - 1))
names(user2idx) = u_unique
user2idx[1:10]

m_unique = unique(ratings$movieId)
movie2idx = as.integer(seq(0, length(m_unique) - 1))
names(movie2idx) = m_unique
movie2idx[1:10]
```

```{r}
ratings$userIdx = user2idx[as.character(ratings$userId)]
ratings$movieIdx = movie2idx[as.character(ratings$movieId)]

n_users = length(u_unique)
n_movies = length(m_unique)

n_users; n_movies
```
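For reference, the same contiguous re-indexing can be written in Python as a dictionary comprehension (an illustrative sketch with made-up ids, not part of the fastai pipeline):

```python
# Map arbitrary ids to contiguous 0-based indices, as done for
# userId and movieId in the R chunks above (illustrative sketch).
user_ids = [1, 4, 4, 7, 1]  # hypothetical raw ids
user2idx = {uid: i for i, uid in enumerate(dict.fromkeys(user_ids))}
user_idx = [user2idx[uid] for uid in user_ids]

print(user2idx)  # {1: 0, 4: 1, 7: 2}
print(user_idx)  # [0, 1, 1, 2, 0]
```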

## Model 1

Each user $i$ is represented by an embedding vector $u_i$ consisting of `n_factors` values. Similarly, a movie $j$ is represented by an embedding vector $m_j$ consisting of `n_factors` values. The model of rating $r_{ij}$ given by user $i$ to movie $j$ is:
$$ r_{ij} = u_i^T m_j$$
The class below defines the model. The constructor for the class takes the dimensions as inputs and initializes the model parameters.
```{r}
py_run_string("
class EmbeddingDot(nn.Module):
    def __init__(self, n_users, n_movies, n_factors):
        super().__init__()
        self.u = nn.Embedding(n_users, n_factors)
        self.m = nn.Embedding(n_movies, n_factors)
        self.u.weight.data.uniform_(0,0.05)
        self.m.weight.data.uniform_(0,0.05)
        
    def forward(self, cats, conts):
        users,movies = cats[:,0],cats[:,1]
        u,m = self.u(users),self.m(movies)
        return (u*m).sum(1)              
              ")
```

The `x` dataframe just includes the sequential user and movie indices. `y` is a numpy array of ratings with type float32.
```{r}
x = ratings[, c("userIdx", "movieIdx")]
y = np_array(ratings$rating)$astype(py$np$float32)
```

The validation data rows are selected and `n_factors` is set to 50.
```{r}
val_idxs = py$get_cv_idxs(nrow(ratings))
val_idxs = as.integer(val_idxs)
n_factors = 50L
```
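If I recall fastai's implementation correctly, `get_cv_idxs` simply draws a random subset of row indices (20% by default) for validation. A hedged pure-Python sketch of that behavior (function name and details are my own, not fastai's):

```python
import random

def get_val_idxs(n, val_pct=0.2, seed=42):
    # Sample a random val_pct fraction of row indices for validation.
    # A sketch of what fastai's get_cv_idxs does; details may differ.
    rng = random.Random(seed)
    return sorted(rng.sample(range(n), int(n * val_pct)))

idxs = get_val_idxs(100)
print(len(idxs))  # 20
```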

The data loader object is created.
```{r}
data = py$ColumnarModelData$from_data_frame(datapath, val_idxs, r_to_py(x), y, c("userIdx", "movieIdx"), 64L)
```
The model is defined using the class `EmbeddingDot` defined above. Based on the model, the optimizer `opt` is defined.
```{r}
wd=1e-5
model = py$EmbeddingDot(n_users, n_movies, n_factors)$cuda()
opt = py$optim$SGD(model$parameters(), 1e-1, weight_decay=wd, momentum=0.9)
```

The model details are listed below
```{r}
model
```
The model is fit. In a Jupyter notebook, a widget shows a nice progress output. That output doesn't render properly within RStudio, and the generated html document has too much output. For now, I have turned off the output and just store the final validation loss metric (MSE loss in this case).
```{r results="hide"}
mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
```
```{r}
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
```

The model is run for a few more epochs.
```{r results="hide"}
py$set_lrs(opt, 0.01)
mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
```
```{r}
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
```

An alternate way to check the validation MSE is to explicitly get predictions from the model for the validation data and compare them to the actual validation ratings.
```{r}
yval_preds = py$predict(model, data$val_dl)
yval=data$val_y[,1]
mse_loss = mean((yval_preds - yval)**2)
mse_loss
```

## Model 2
This is model 1 with added bias terms for users and movies. Each user $i$ is represented by an embedding vector $u_i$ consisting of `n_factors` values and a user bias value $ub_i$. Similarly, a movie $j$ is represented by an embedding vector $m_j$ consisting of `n_factors` values and a movie bias value $mb_j$. In addition, the output is constrained to be between the minimum rating `r_min` and the maximum rating `r_max` using a sigmoid function. The model of rating $r_{ij}$ given by user $i$ to movie $j$ is:
$$ r_{ij} = \frac{e^{u_i^T m_j + ub_i + mb_j}}{e^{u_i^T m_j + ub_i + mb_j} + 1}(r_{max} - r_{min}) + r_{min} $$
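The sigmoid squashing guarantees predictions stay inside the rating range. A minimal numeric check in pure Python (assuming `r_min = 0.5` and `r_max = 5.0`, the MovieLens extremes):

```python
import math

def scaled_sigmoid(x, r_min=0.5, r_max=5.0):
    # sigma(x) * (r_max - r_min) + r_min, the scaling used in the
    # forward method of the bias model below.
    return 1 / (1 + math.exp(-x)) * (r_max - r_min) + r_min

print(scaled_sigmoid(-10))  # close to r_min = 0.5
print(scaled_sigmoid(0))    # midpoint 2.75
print(scaled_sigmoid(10))   # close to r_max = 5.0
```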
```{r}
py_run_string("
def get_emb(ni,nf):
    e = nn.Embedding(ni, nf)
    e.weight.data.uniform_(-0.01,0.01)
    return e
              
class EmbeddingDotBias(nn.Module):
    def __init__(self, n_users, n_movies, n_factors, min_rating, max_rating):
        super().__init__()
        (self.u, self.m, self.ub, self.mb) = [get_emb(*o) for o in [
              (n_users, n_factors), (n_movies, n_factors), (n_users,1), (n_movies,1)
              ]]
        self.max_rating = max_rating
        self.min_rating = min_rating
              
    def forward(self, cats, conts):
        users,movies = cats[:,0],cats[:,1]
        um = (self.u(users)* self.m(movies)).sum(1)
        res = um + self.ub(users).squeeze() + self.mb(movies).squeeze()
        res = F.sigmoid(res) * (self.max_rating-self.min_rating) + self.min_rating
        return res              
              ")
```

The model and the optimizer `opt` are defined.
```{r}
wd=2e-4
min_rating = min(ratings$rating)
max_rating = max(ratings$rating)
model = py$EmbeddingDotBias(n_users, n_movies, n_factors, min_rating, max_rating)$cuda()
opt = py$optim$SGD(model$parameters(), 1e-1, weight_decay=wd, momentum=0.9)
```

The model details are listed below
```{r}
model
```

```{r results = "hide"}
mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
```
```{r}
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
```
Change the learning rate and refit
```{r results = "hide"}
py$set_lrs(opt, 0.01)
mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
```
```{r}
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
```

## Model 3

This model is a deep learning model with the following layers:

1. Input - Concatenation of the user embedding vector and movie embedding vector (size = `2*n_factors`)
2. Dropout with dropout rate `p1=0.05`
3. Linear fully connected layer with output size `nh = 10`
4. ReLU
5. Dropout with dropout rate `p2=0.5`
6. Linear fully connected layer with output size 1
7. Apply the sigmoid function and scale the result to lie between `min_rating - 0.5` and `max_rating + 0.5`

```{r}
py_run_string("
class EmbeddingNet(nn.Module):
    def __init__(self, n_users, n_movies, n_factors, min_rating, max_rating, nh=10, p1=0.05, p2=0.5):
        super().__init__()
        (self.u, self.m) = [get_emb(*o) for o in [
            (n_users, n_factors), (n_movies, n_factors)]]
        self.lin1 = nn.Linear(n_factors*2, nh)
        self.lin2 = nn.Linear(nh, 1)
        self.drop1 = nn.Dropout(p1)
        self.drop2 = nn.Dropout(p2)
        self.min_rating = min_rating
        self.max_rating = max_rating
        
    def forward(self, cats, conts):
        users,movies = cats[:,0],cats[:,1]
        x = self.drop1(torch.cat([self.u(users),self.m(movies)], dim=1))
        x = self.drop2(F.relu(self.lin1(x)))
        return F.sigmoid(self.lin2(x)) * (self.max_rating-self.min_rating+1) + self.min_rating-0.5
           ")
```
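Note that the output scaling in this `forward` differs from Model 2: it multiplies by `(max_rating - min_rating + 1)` and shifts by `min_rating - 0.5`, so predictions range over (min − 0.5, max + 0.5) rather than (min, max), giving the sigmoid some headroom at the extreme ratings. Numerically (pure-Python sketch):

```python
import math

def model3_scale(x, r_min=0.5, r_max=5.0):
    # sigmoid(x) * (r_max - r_min + 1) + r_min - 0.5,
    # mirroring the last line of EmbeddingNet.forward above.
    return 1 / (1 + math.exp(-x)) * (r_max - r_min + 1) + r_min - 0.5

print(model3_scale(-20))  # approaches r_min - 0.5 = 0.0
print(model3_scale(20))   # approaches r_max + 0.5 = 5.5
```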

The model is defined, this time with the Adam optimizer.
```{r}
wd=1e-5
model = py$EmbeddingNet(n_users, n_movies, n_factors, min_rating, max_rating)$cuda()
opt = py$optim$Adam(model$parameters(), 1e-3, weight_decay=wd)
```

The model details are listed below
```{r}
model
```

```{r results = "hide"}
mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
```
```{r}
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
```
The learning rate is changed and the model is refit.
```{r results="hide"}
py$set_lrs(opt, 0.01)
mdlfit = py$fit(model, data, 3L, opt, py$F$mse_loss)
```
```{r}
paste0("MSE of validation set = ", round(mdlfit[[1]], 3))
```

## Summary

This shows how a model can be developed from scratch. But the model still needs to be defined as a Python class. I am guessing that if somebody is developing a model from scratch, it might be better to do it entirely in python and create wrapper functions which can then be used from R.

```{r}
sessionInfo()
```