How is TensorFlow used with Estimators to build a linear model and load the Titanic dataset?


A linear model built with Estimators can load the Titanic dataset using the ‘read_csv’ method from the ‘pandas’ package. The method is given the URLs of CSV files, hosted on Google's servers, that contain the Titanic training and evaluation data. It reads each file and returns its contents as a pandas DataFrame.
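To show what ‘read_csv’ returns without downloading anything, here is a minimal sketch that parses a few made-up rows from an in-memory string. The rows and the column names (`survived`, `sex`, `age`, `class`) are hypothetical stand-ins, not the real Titanic data. `io.StringIO` plays the role of the remote file, since `read_csv` accepts a path, a URL, or any file-like object.

```python
import io

import pandas as pd

# Hypothetical rows standing in for the real Titanic CSV;
# the column names are assumptions chosen for illustration.
csv_text = """survived,sex,age,class
0,male,22.0,Third
1,female,38.0,First
1,female,26.0,Third
"""

# read_csv accepts a file path, a URL, or any file-like object;
# io.StringIO lets us parse the text without touching the network.
df = pd.read_csv(io.StringIO(csv_text))

print(df.shape)          # (3, 4)
print(list(df.columns))  # ['survived', 'sex', 'age', 'class']
print(df.dtypes["age"])  # float64
```

Passing a URL string instead of the `StringIO` object makes pandas fetch the file over the network and parse it the same way.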


The Keras Sequential API builds a sequential model: a plain stack of layers in which every layer has exactly one input tensor and one output tensor.
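As a framework-free sketch of the "plain stack of layers" idea, the hypothetical `SimpleSequential` class below feeds one input through a list of layers in order, each layer producing exactly one output that becomes the next layer's input. It does not use Keras; the `dense` and `relu` helpers and the hand-picked weights are assumptions for illustration only.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Layer = Callable[[Vector], Vector]


def dense(weights: Sequence[Sequence[float]], bias: Sequence[float]) -> Layer:
    """Return a fully connected layer computing y = W x + b."""
    def layer(x: Vector) -> Vector:
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(weights, bias)]
    return layer


def relu(x: Vector) -> Vector:
    """Element-wise ReLU activation."""
    return [max(0.0, v) for v in x]


class SimpleSequential:
    """Hypothetical stand-in for a plain stack of layers: each layer takes
    one input and produces one output, which is fed to the next layer."""

    def __init__(self, layers: List[Layer]):
        self.layers = layers

    def __call__(self, x: Vector) -> Vector:
        for layer in self.layers:
            x = layer(x)
        return x


model = SimpleSequential([
    dense([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]),  # 2 inputs -> 2 units
    relu,
    dense([[2.0, 1.0]], [1.0]),                    # 2 units -> 1 output
])
print(model([3.0, 1.0]))  # [7.0]
```

Because every layer has a single input and a single output, the whole model reduces to function composition, which is exactly what the loop in `__call__` does.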

A neural network that contains at least one convolutional layer is known as a convolutional neural network (CNN). A CNN can be used to build a learning model.
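The core operation of a convolutional layer is sliding a small kernel over the input and taking a weighted sum at each position. The sketch below implements that in plain Python for a single 2D channel with stride 1 and no padding. The function name `conv2d_valid` and the toy image and kernel are hypothetical; like most deep learning libraries, it computes a cross-correlation (the kernel is not flipped).

```python
from typing import List

Matrix = List[List[float]]


def conv2d_valid(image: Matrix, kernel: Matrix) -> Matrix:
    """Slide `kernel` over `image` with stride 1 and no padding
    ("valid" mode), summing element-wise products at each position."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    out: Matrix = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            acc = 0.0
            for di in range(kh):
                for dj in range(kw):
                    acc += image[i + di][j + dj] * kernel[di][dj]
            row.append(acc)
        out.append(row)
    return out


image = [
    [1, 2, 3, 0],
    [4, 5, 6, 0],
    [7, 8, 9, 0],
]
kernel = [
    [1, 0],
    [0, -1],
]
print(conv2d_valid(image, kernel))  # [[-4.0, -4.0, 3.0], [-4.0, -4.0, 6.0]]
```

A real convolutional layer repeats this for many learned kernels and input channels, adds a bias, and usually applies an activation, but the sliding weighted sum is the same.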

TensorFlow Text contains a collection of text-related classes and ops that can be used with TensorFlow 2.0. It can be used to preprocess text for sequence modelling.

We are using Google Colaboratory to run the code below. Google Colab (Colaboratory) runs Python code in the browser, requires zero configuration, and provides free access to GPUs (Graphics Processing Units). Colaboratory is built on top of Jupyter Notebook.

An Estimator is TensorFlow's high-level representation of a complete model. It is designed for easy scaling and asynchronous training.

We will train a logistic regression model using the tf.estimator API. This model serves as a baseline for other algorithms. We use the Titanic dataset with the goal of predicting passenger survival from characteristics such as gender, age, and class.
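Logistic regression models the survival probability as the sigmoid of a weighted sum of the features. The sketch below fits such a model in plain Python with batch gradient descent on the average cross-entropy loss; this optimizer is chosen for simplicity and is not necessarily how tf.estimator trains its linear models. The two binary features (`is_female`, `is_first_class`) and the eleven labelled rows are made up for illustration, not taken from the Titanic data.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def predict_proba(w: Sequence[float], b: float, x: Sequence[float]) -> float:
    """P(label = 1 | x) = sigmoid(w . x + b)."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


def train_logistic_regression(
    xs: Sequence[Sequence[float]], ys: Sequence[int],
    lr: float = 0.5, epochs: int = 2000,
) -> Tuple[List[float], float]:
    """Fit weights w and bias b by batch gradient descent
    on the average binary cross-entropy loss."""
    n, n_features = len(xs), len(xs[0])
    w = [0.0] * n_features
    b = 0.0
    for _ in range(epochs):
        grad_w = [0.0] * n_features
        grad_b = 0.0
        for x, y in zip(xs, ys):
            err = predict_proba(w, b, x) - y  # d(loss)/d(logit)
            for k in range(n_features):
                grad_w[k] += err * x[k]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


# Hypothetical rows: [is_female, is_first_class] -> survived (1) or not (0).
xs = [[0.0, 0.0]] * 4 + [[1.0, 0.0]] * 3 + [[0.0, 1.0]] * 2 + [[1.0, 1.0]] * 2
ys = [1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1]

w, b = train_logistic_regression(xs, ys)
print("weights:", [round(v, 2) for v in w], "bias:", round(b, 2))
print("P(survive | female, first class):", round(predict_proba(w, b, [1.0, 1.0]), 2))
print("P(survive | male, not first class):", round(predict_proba(w, b, [0.0, 0.0]), 2))
```

Positive weights indicate features that raise the predicted survival probability; a baseline like this is what more complex models are compared against.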

pip install -q scikit-learn


import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output
from six.moves import urllib
import tensorflow.compat.v2.feature_column as fc
import tensorflow as tf
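print("Load the dataset")
# Read the training and evaluation CSV files into pandas DataFrames
dftrain = pd.read_csv('')
dfeval = pd.read_csv('')
print("Removing feature 'survived'")
# Pop the label column 'survived' out of each DataFrame;
# the remaining columns are the input features
y_train = dftrain.pop('survived')
y_eval = dfeval.pop('survived')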

Code credit −


Load the dataset
Removing feature 'survived'


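  • The required packages are imported.
  • The training and evaluation CSV files are read into pandas DataFrames.
  • The ‘survived’ column is removed from each DataFrame with ‘pop’ and kept separately as the labels (y_train and y_eval).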
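The label-splitting step above relies on `DataFrame.pop`, which removes a column in place and returns it. A minimal sketch on a tiny hypothetical DataFrame (the values and the `sex`/`age` column names are assumptions, not the real data):

```python
import pandas as pd

# Hypothetical stand-in for dftrain; values and column names are made up.
dftrain = pd.DataFrame({
    "survived": [0, 1, 1],
    "sex": ["male", "female", "female"],
    "age": [22.0, 38.0, 26.0],
})

# DataFrame.pop removes the column in place and returns it as a Series.
y_train = dftrain.pop("survived")

print(list(dftrain.columns))  # ['sex', 'age']
print(y_train.tolist())       # [0, 1, 1]
print(y_train.name)           # survived
```

After the call, `dftrain` holds only the features and `y_train` holds the labels, which is the split a classifier expects.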
Updated on 22-Feb-2021 10:43:50