Apache MXNet - KVStore and Visualization



This chapter deals with the Python packages KVStore and visualization.

KVStore package

KVStore stands for Key-Value store. It is a critical component for multi-device training: parameters are communicated across devices on a single machine, as well as across multiple machines, through one or more servers that hold the parameters in a KVStore.

Let us understand the working of KVStore with the help of following points:

  • Each entry in KVStore is a key-value pair.

  • Each parameter array in the network is assigned a key, and the weights of that parameter array are the value stored under that key.

  • The worker nodes push gradients after processing a batch, and pull the updated weights before processing a new batch.

In simple words, KVStore is a place for data sharing, where each device can push data in and pull data out.
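The push/pull pattern described above can be sketched in plain Python. This is only a toy model, not MXNet code: the class name ToyKVStore, the key "fc1_weight" and the plain SGD step with a learning rate of 0.5 are all assumptions made for illustration. In MXNet, the rule that merges a pushed value into the stored one is set by the updater or optimizer.

```python
# Toy, in-process model of the push/pull pattern.
# ToyKVStore is a hypothetical name; it is not part of MXNet.
class ToyKVStore:
    def __init__(self):
        self._store = {}

    def init(self, key, value):
        # Register a key together with its initial weights.
        self._store[key] = list(value)

    def push(self, key, gradient, lr=0.5):
        # Assumed merge rule: one plain SGD step with the pushed gradient.
        weights = self._store[key]
        self._store[key] = [w - lr * g for w, g in zip(weights, gradient)]

    def pull(self, key):
        # Return a copy so callers cannot modify the stored weights directly.
        return list(self._store[key])


store = ToyKVStore()
store.init("fc1_weight", [1.0, 1.0])

# Two workers each push a gradient after processing a batch ...
store.push("fc1_weight", [1.0, 2.0])
store.push("fc1_weight", [1.0, 2.0])

# ... and pull the updated weights before the next batch.
weights = store.pull("fc1_weight")
print(weights)  # [0.0, -1.0]
```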

Data Push-In and Pull-Out

KVStore can be thought of as a single object shared across different devices, such as GPUs and computers, where each device can push data in and pull data out.

Following are the implementation steps that devices need to follow to push data in and pull data out:

Implementation steps

Initialisation − The first step is to initialise the values. In our example, we initialise an (int, NDArray) pair in KVStore and then pull the value out −

import mxnet as mx
kv = mx.kv.create('local') # create a local KVStore.
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())

Output

This produces the following output −

[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]

Push, Aggregate, and Update − Once a key is initialised, we can push a new value of the same shape to that key −

kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a)
print(a.asnumpy())

Output

The output is given below −

[[8. 8. 8.]
 [8. 8. 8.]
 [8. 8. 8.]]

The data to be pushed can reside on any device. We can also push multiple values to the same key. In this case, KVStore first sums all of these values and then pushes the aggregated value, as follows −

contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())

Output

You will see the following output −

[[4. 4. 4.]
 [4. 4. 4.]
 [4. 4. 4.]]

For each push, KVStore combines the pushed value with the value already stored, using an updater. The default updater is ASSIGN, which replaces the stored value with the pushed one. A custom updater can be set to control how the data is merged −

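def update(key, input, stored):
    # Custom updater: add twice the pushed value to the stored value.
    print("update on key: %d" % key)
    stored += input * 2

kv._set_updater(update)  # replace the default ASSIGN updater
kv.pull(3, out=a)
print(a.asnumpy())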

Output

When you execute the above code, you should see the following output −

[[4. 4. 4.]
 [4. 4. 4.]
 [4. 4. 4.]]

Example

kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())

Output

Given below is the output of the code −

update on key: 3
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]
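The values seen so far for key 3 (8, then 4, then 6) can be traced with a small plain-Python model of one array element. This is a sketch under stated assumptions, not MXNet's implementation: the function names are hypothetical, and a scalar stands in for each 3x3 NDArray.

```python
# Toy model of how a push is merged into the stored value through an updater.
# All names are hypothetical; a float stands in for one NDArray element.
def assign_updater(key, pushed, stored):
    # Default behaviour described in the text: overwrite the stored value.
    return pushed

def custom_updater(key, pushed, stored):
    # Same rule as the update function in this chapter: stored += input * 2
    return stored + pushed * 2

def push(store, key, values, updater):
    # Values pushed to one key are summed first, then merged by the updater.
    aggregated = sum(values)
    store[key] = updater(key, aggregated, store[key])

store = {3: 2.0}                                      # kv.init(3, ones * 2)

push(store, 3, [8.0], assign_updater)                 # single push of 8
after_assign = store[3]

push(store, 3, [1.0, 1.0, 1.0, 1.0], assign_updater)  # four devices push ones
after_sum = store[3]

push(store, 3, [1.0], custom_updater)                 # 4 + 1 * 2
after_custom = store[3]

print(after_assign, after_sum, after_custom)          # 8.0 4.0 6.0
```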

Pull − As with push, we can also pull the value onto several devices with a single call, as follows −

b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())

Output

The output is stated below −

[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]

Complete Implementation Example

Given below is the complete implementation example −

import mxnet as mx
kv = mx.kv.create('local')
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a) # pull out the value
print(a.asnumpy())
contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())
def update(key, input, stored):
   print("update on key: %d" % key)
   stored += input * 2
kv._set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())

Handling Key-Value Pairs

All the operations we have implemented above involve a single key, but KVStore also provides an interface for a list of key-value pairs.

For a single device

Following is an example of the KVStore interface for a list of key-value pairs on a single device −

keys = [5, 7, 9]
kv.init(keys, [mx.nd.ones(shape)]*len(keys))
kv.push(keys, [mx.nd.ones(shape)]*len(keys))
b = [mx.nd.zeros(shape)]*len(keys)
kv.pull(keys, out = b)
print(b[1].asnumpy())

Output

You will receive the following output −

update on key: 5
update on key: 7
update on key: 9
[[3. 3. 3.]
 [3. 3. 3.]
 [3. 3. 3.]]
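One detail of the listing above is worth noting: in Python, [x] * n repeats a reference to the same object, so the three entries of b refer to one array. That is harmless here because every key holds the same value, but it would hide differences between keys. A minimal sketch, with plain lists standing in for NDArrays (an assumption made so that the snippet runs without MXNet):

```python
# Plain lists stand in for NDArrays; the aliasing behaviour of `*` is the same.
shared = [[0.0, 0.0]] * 3                   # three references to ONE inner list
separate = [[0.0, 0.0] for _ in range(3)]   # three independent inner lists

shared[0][0] = 3.0
separate[0][0] = 3.0

print(shared)     # [[3.0, 0.0], [3.0, 0.0], [3.0, 0.0]]
print(separate)   # [[3.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
```

When each key is expected to hold a different value, building the output list with a list comprehension avoids the shared reference.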

For multiple devices

Following is an example of the KVStore interface for a list of key-value pairs on multiple devices −

b = [[mx.nd.ones(shape, ctx) for ctx in contexts]] * len(keys)
kv.push(keys, b)
kv.pull(keys, out = b)
print(b[1][1].asnumpy())

Output

You will see the following output −

update on key: 5
update on key: 7
update on key: 9
[[11. 11. 11.]
 [11. 11. 11.]
 [11. 11. 11.]]
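The value 11 follows from the earlier steps. The arithmetic for one array element can be checked in plain Python, assuming the custom updater (stored += input * 2) is still installed and each key holds 3 from the single-device example:

```python
# Worked arithmetic for one element of the array stored under each key.
stored = 3.0                     # value left by the single-device example
pushed = [1.0, 1.0, 1.0, 1.0]    # an array of ones from each of the 4 contexts

aggregated = sum(pushed)         # values pushed to one key are summed: 4.0
stored += aggregated * 2         # custom updater: stored += input * 2

print(stored)  # 11.0
```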

Visualization package

The visualization package is the Apache MXNet package used to represent a neural network (NN) as a computation graph that consists of nodes and edges.

Visualize neural network

In the example below, we will use mx.viz.plot_network to visualize a neural network. The prerequisites are as follows −

Prerequisites

  • Jupyter notebook

  • Graphviz library

Implementation Example

In the example below we will visualize a sample NN for linear matrix factorisation −

import mxnet as mx
user = mx.symbol.Variable('user')
item = mx.symbol.Variable('item')
score = mx.symbol.Variable('score')

# Set the dummy dimensions
k = 64
max_user = 100
max_item = 50

# The user feature lookup
user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)

# The item feature lookup
item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)

# Predict by the inner product: element-wise product, then sum over the embedding axis
N_net = user * item
N_net = mx.symbol.sum_axis(data = N_net, axis = 1)
N_net = mx.symbol.Flatten(data = N_net)

# Defining the loss layer
N_net = mx.symbol.LinearRegressionOutput(data = N_net, label = score)

# Visualize the network
mx.viz.plot_network(N_net)
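The prediction step in this network multiplies the user and item embeddings element-wise and then sums over the embedding axis, which amounts to an inner product for each (user, item) pair. A small plain-Python sketch of that computation for a single pair follows. The embedding values are made up, and k is 4 rather than 64, purely for illustration.

```python
# Hypothetical embeddings for one user and one item (k = 4 for brevity).
user_vec = [0.5, 1.0, -1.0, 2.0]
item_vec = [2.0, 0.5, 1.0, 0.25]

# Element-wise product, followed by a sum over the embedding axis.
product = [u * v for u, v in zip(user_vec, item_vec)]
predicted_score = sum(product)

print(product)          # [1.0, 0.5, -1.0, 0.5]
print(predicted_score)  # 1.0
```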