JeoPydy!

Building on the data ingestion pipeline I showed last time, today I went a bit further and started to build out a website based on that data! It’s still very early days, but I did get far enough through to have a functioning JSON endpoint. Here are the steps that I went through to get to an initial basic version of JeoPydy!

Database enhancements

My first step was to finish the data ingestion pipeline by adding constraints to the schema. I did this by defining the schema explicitly in SQL instead of relying on the auto-generated schema from DataFrame.to_sql(), which gives me greater control over the resulting database:

DROP TABLE IF EXISTS question;
DROP TABLE IF EXISTS category;
DROP TABLE IF EXISTS round;
DROP TABLE IF EXISTS show;
CREATE TABLE round (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  round_name TEXT NOT NULL
);
CREATE TABLE show (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  air_date TEXT NOT NULL
);
CREATE TABLE category (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  round_id INTEGER NOT NULL,
  show_id INTEGER NOT NULL,
  name TEXT NOT NULL,
  FOREIGN KEY (round_id) REFERENCES round (id),
  FOREIGN KEY (show_id) REFERENCES show (id)
);
CREATE TABLE question (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  category_id INTEGER NOT NULL,
  value INTEGER,
  question TEXT NOT NULL,
  answer TEXT,
  FOREIGN KEY (category_id) REFERENCES category (id)
);

The key things I’m doing here are:

  • Types: DataFrame columns are also typed, so these would have been mapped across to the SQL table anyway, but I still think it’s good to have explicit control over them, especially since my types are not dynamic
  • Constraints: I have set the id columns to be autoincremented primary keys, and I have also set foreign key constraints for the relevant columns to ensure referential integrity, i.e. that a record I create in a linked table has a valid reference to the foreign table. The foreign key constraint is useful to make sure I’m not accidentally creating any orphaned rows during the ingestion pipeline, but the autoincremented primary key will only really be useful if I start adding data to the database through a separate method.
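One caveat to be aware of: SQLite doesn’t actually enforce foreign key constraints unless you switch them on for each connection. A minimal sketch of the failure mode, using a toy two-table subset of the schema above:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")  # SQLite leaves FK enforcement off by default!

conn.executescript("""
CREATE TABLE round (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  round_name TEXT NOT NULL
);
CREATE TABLE category (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  round_id INTEGER NOT NULL,
  name TEXT NOT NULL,
  FOREIGN KEY (round_id) REFERENCES round (id)
);
""")

conn.execute("INSERT INTO round (round_name) VALUES ('Jeopardy!')")
conn.execute("INSERT INTO category (round_id, name) VALUES (1, 'HISTORY')")  # valid parent row

try:
    # round 99 doesn't exist, so this orphan row is rejected
    conn.execute("INSERT INTO category (round_id, name) VALUES (99, 'ORPHAN')")
    orphan_rejected = False
except sqlite3.IntegrityError:
    orphan_rejected = True
```

Without the PRAGMA, that orphan insert would silently succeed, so it’s worth running on every connection the ingestion pipeline opens.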

It’s worth saying I could probably use SQLAlchemy to do this instead of writing raw SQL, but that’s a project for another time.
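For the curious, here’s a rough sketch of what that might look like in SQLAlchemy Core (an illustration mirroring the schema above, not code from this project):

```python
from sqlalchemy import (
    create_engine, MetaData, Table, Column, Integer, Text, ForeignKey
)

metadata = MetaData()

round_table = Table(
    "round", metadata,
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("round_name", Text, nullable=False),
)

show = Table(
    "show", metadata,
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("air_date", Text, nullable=False),
)

category = Table(
    "category", metadata,
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("round_id", Integer, ForeignKey("round.id"), nullable=False),
    Column("show_id", Integer, ForeignKey("show.id"), nullable=False),
    Column("name", Text, nullable=False),
)

question = Table(
    "question", metadata,
    Column("id", Integer, primary_key=True, autoincrement=True),
    Column("category_id", Integer, ForeignKey("category.id"), nullable=False),
    Column("value", Integer),
    Column("question", Text, nullable=False),
    Column("answer", Text),
)

# create_all emits the equivalent CREATE TABLE statements
engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)
```

The appeal is that the same metadata would also work unchanged against Postgres or MySQL later on.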

Web API

My next step was to write the actual web API. The eventual aim of the project is to write a website that will allow interactions, but my front-end experience is limited, so it makes sense to start with what I know. The HTML and CSS trickery can come later.

In previous projects, I have used the base http.server (https://docs.python.org/3.9/library/http.server.html) provided in core Python, but I wouldn’t do it again. Recent versions of Python (3.7+, I think) added a ThreadingHTTPServer class that serves each request in a separate thread, which makes handling concurrent requests much more performant, but there’s still a whole lot of plumbing needed just to write a server that can handle basic GET calls. And if you want anything like authentication, SSL, session support, etc., you have to implement it manually. It’s kind of fun hand-writing the handling for every method on every endpoint, but only up to a point…
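To illustrate the plumbing, here’s roughly what a single JSON GET endpoint looks like on ThreadingHTTPServer with nothing but the standard library (a toy /api route, not the JeoPydy code):

```python
import json
import threading
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.request import urlopen


class ApiHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        # all routing, serialization, and header bookkeeping is on us
        if self.path == "/api":
            body = json.dumps({"status": "ok"}).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)
        else:
            self.send_error(404)

    def log_message(self, format, *args):
        pass  # keep the demo quiet


server = ThreadingHTTPServer(("127.0.0.1", 0), ApiHandler)  # port 0 = any free port
threading.Thread(target=server.serve_forever, daemon=True).start()

with urlopen(f"http://127.0.0.1:{server.server_port}/api") as resp:
    payload = json.loads(resp.read().decode("utf-8"))

server.shutdown()
```

And that’s with no authentication, no SSL, and only one route; Flask collapses all of this into a decorated function.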

The alternative is the Flask (https://flask.palletsprojects.com/en/1.1.x/) framework, and the major advantage is that it does all of this for you. I don’t have a whole lot of experience with Flask, but I have previously worked my way through the excellent Flask tutorial (https://flask.palletsprojects.com/en/1.1.x/tutorial/), which does a great job of introducing the concepts and philosophy of Flask in the context of a real application. So many package tutorials I have worked through neglect to do this, and it makes the learning curve a lot steeper when you come to use the package in a real project. Having a tutorial that both follows best practices and results in a production-ready application makes it so much easier to understand how to apply the package to your specific requirements, as opposed to just what it does from a functionality perspective. It’s so good, it’s also a great place to start from if you are writing your own code, which I unashamedly did here. It’s released under the 3-clause BSD license (https://flask.palletsprojects.com/en/1.1.x/license/), so there are no legal or ethical issues with doing that either.

The relevant parts of the code are the application factory, the database initialization function, and then a single blueprint for accessing the data and returning a dictionary as a JSON response:

# __init__.py

import os
from flask import Flask


def create_app(test_config=None):
    # create and configure the app
    app = Flask(__name__, instance_relative_config=True)
    app.config.from_mapping(
        SECRET_KEY='dev',
        DATABASE=os.path.join('jeopydy', 'jeopydy.db'),
    )

    if test_config is not None:
        # allow e.g. tests to override the default config
        app.config.from_mapping(test_config)

    # ensure the instance folder exists
    try:
        os.makedirs(app.instance_path)
    except OSError:
        pass

    from . import db
    db.init_app(app)

    from . import round
    app.register_blueprint(round.bp)

    return app


# db.py

import sqlite3
import click
from flask import current_app, g
from flask.cli import with_appcontext
import jeopydy.build_database


def get_db():
    if 'db' not in g:
        g.db = sqlite3.connect(
            current_app.config['DATABASE'],
            detect_types=sqlite3.PARSE_DECLTYPES
        )

        g.db.row_factory = sqlite3.Row
    return g.db


def close_db(e=None):
    db = g.pop('db', None)
    if db is not None:
        db.close()


def init_db():
    db = get_db()

    with current_app.open_resource('schema.sql') as f:
        db.executescript(f.read().decode('utf8'))

    jeopydy.build_database.build(db)


@click.command('init-db')
@with_appcontext
def init_db_command():
    """Clear the existing data and create new tables."""
    init_db()
    click.echo('Initialized the database.')


def init_app(app):
    app.teardown_appcontext(close_db)
    app.cli.add_command(init_db_command)


# round.py

import datetime

from flask import Blueprint

from jeopydy.db import get_db

bp = Blueprint('round', __name__)


def get_round():
    db = get_db()
    round = db.execute(
        "SELECT c.name, s.air_date, value, question, answer "
        "FROM question as q "
        "INNER JOIN category c on c.id = q.category_id "
        "INNER JOIN show s on s.id = c.show_id "
        "WHERE q.category_id in ("
            "SELECT c.id "
            "FROM category as c "
            "WHERE c.round_id = 0 "
            "ORDER BY RANDOM() LIMIT 6) "
    ).fetchall()

    response = dict()
    for category_name, air_date, value, question, answer in round:
        category_date = f"{category_name} " \
                        f"({datetime.datetime.strptime(air_date, '%Y-%m-%d %H:%M:%S').strftime('%B %-d, %Y')})"
        if category_date not in response:
            response[category_date] = dict()
        response[category_date][value] = question, answer
    return response


@bp.route('/api')
def api():
    return get_round()

The tutorial referenced above gives a thorough run-through of what each line/block does, so I’ll only cover the parts specific to my project here.

__init__.py

The presence of an __init__.py file denotes the folder it is in as a package, and the code within that file is executed whenever that package is imported. In the world of Flask, that means the Flask application can be written as its own package, which helps in managing and organizing the code, e.g. separate parts of the application can be encapsulated into separate modules, allowing a clearer separation of concerns. My application is nowhere near that scale yet, but as I start adding the concept of users, PvP games, etc., it probably will be.

The code itself is pretty basic: it just implements a factory pattern to create the Flask app instance (JeoPydy – every Python application needs a witty name, right?) and registers my API blueprint, defined below.

db.py

The db.py file is responsible for creating the database if it doesn’t already exist, and managing the connection to the database. Here I import the module I wrote in the previous post that actually processes and imports the data, but first I execute the SQL schema above to create the tables with the appropriate constraints.

One really cool feature here is the @click.command('init-db') decorator. The Click library (https://click.palletsprojects.com/en/7.x/) lets you turn a function into a command-line application with barely any code. In this case, adding the decorator to the init_db_command() function lets us run flask init-db from the command line, which automatically sets up the Flask app and initializes the database, without us having to invoke Python with a specific function of a specific module. Beautiful.
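To show just how little code Click needs, here’s a stand-alone toy command (a hypothetical greet command, nothing to do with JeoPydy):

```python
import click


@click.command('greet')
@click.option('--count', default=1, help='Number of greetings.')
@click.argument('name')
def greet(count, name):
    """Greet NAME the given number of times."""
    # Click parses the arguments, generates --help, and validates types for us
    for _ in range(count):
        click.echo(f"Hello, {name}!")


if __name__ == '__main__':
    greet()
```

Running `python greet.py --count 2 Alex` prints the greeting twice, and `--help` comes for free; the decorators do all the argument-parsing work.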

round.py

Finally, round.py actually defines the API endpoint that returns the questions and answers for a Jeopardy round. The query is included in the Python code above, but I’ll break it out here separately:

SELECT c.name, s.air_date, value, question, answer
FROM question as q
INNER JOIN category c on c.id = q.category_id
INNER JOIN show s on s.id = c.show_id
WHERE q.category_id in (
    SELECT c.id
    FROM category as c
    WHERE c.round_id = 0
    ORDER BY RANDOM() LIMIT 6
)

The main query selects a set of rows from the question table, along with the category name (very important!) and original air date (sometimes important for contemporaneous reasons) for each. However, the query needs to limit the set of rows to those linked to 6 randomly selected categories from a single round (round_id 0 is Jeopardy!; filtering on round_id 1 would give Double Jeopardy! categories instead). I achieved this with a sub-query, where I first randomly select 6 categories and then feed that into the WHERE clause of my main query.

Performance for this is pretty good: it only takes around 50 ms against the SQLite database. Before this goes into production I’ll switch to a more scalable database engine. I have the most experience with MS SQL Server, and since it’s now available on Linux (and there’s an official Docker image (https://hub.docker.com/_/microsoft-mssql-server) – good job Microsoft!), I might give that a go. However, the current setup is fine for this stage of the project. If performance becomes an issue later on I’ll look to denormalize the database to avoid the need for the sub-query, but it’s working well enough right now.

The rest of the round.py code sets up the blueprint. The blueprint contains a single function mapped to the '/api' path, which returns the results of the SQL query above transformed into a JSON payload. Another great feature of Flask is that I don’t need to worry about serializing the JSON before I return it: I just return the dict and Flask figures out the rest automatically!
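That dict-to-JSON convenience (added in Flask 1.1) is easy to see with the test client; a minimal sketch with a made-up payload:

```python
from flask import Flask

app = Flask(__name__)


@app.route('/api')
def api():
    # returning a plain dict is enough; Flask serializes it to JSON,
    # sets the Content-Type header, and stringifies the integer keys
    return {"EPITAPHS": {100: ["\"Author of the Declaration...\"", "Thomas Jefferson"]}}


with app.test_client() as client:
    response = client.get('/api')
```

Note the one wrinkle: JSON object keys are always strings, so the integer 100 comes back as "100" on the client side.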

Response

Here’s the response from the endpoint:

{
  "BUSINESS & INDUSTRY (December 23, 2002)": {
    "200": [
      "This company serving about 30 million customers features the following in its ads [(actor walking across the <i>Jeopardy!</i> stage)  <i>Can you hear me now?  Good!  Can you hear me now?  Good!</i>]",
      "Verizon Wireless"
    ],
    "400": [
      "It was once described as 2 \"beautifully embossed, choc. flavored wafers with a rich cream filling\"",
      "Oreo cookies"
    ],
    "600": [
      "This Australian-born moguls's first U.S. newspaper acquisition was the San Antonio Express & News",
      "Rupert Murdoch"
    ],
    "800": [
      "The \"K\" in K-Mart's name stands for this founder",
      "S.S. Kresge"
    ],
    "1000": [
      "The lion on this French auto company's logo was adopted from Belfort, the city where an early model was made",
      "Peugeot"
    ]
  },
  "EPITAPHS (March 26, 1999)": {
    "100": [
      "\"Author of the Declaration of American Independence...\"",
      "Thomas Jefferson"
    ],
    "200": [
      "\"Quoth the raven nevermore\"",
      "Edgar Allan Poe"
    ],
    "300": [
      "\"Truth and history.  21 men.  The boy bandit king -- he died as he had lived\"",
      "Billy the Kid"
    ],
    "500": [
      "\"Man of 1000 Voices\"",
      "Mel Blanc"
    ],
    "1000": [
      "\"61/61.  Against all odds\"",
      "Roger Maris"
    ]
  },
...
}

The real response contained 6 categories, but I’ve truncated it to 2 here to save space. Something interesting, though: some categories have the current values ($200 to $1,000 in $200 increments), some have half that ($100 to $500 in $100 increments), and some have a kind of hybrid, with $100 to $300 in $100 increments and the final two questions worth $500 and $1,000. I might want to normalize these at some point, possibly by doubling the values of all questions that aired before Jeopardy settled on the current prize money structure.
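That transformation should be a one-liner in pandas once the questions are joined to their air dates. A sketch on toy data, treating the exact cutoff date as an assumption to verify (I believe the board values doubled in late November 2001):

```python
import pandas as pd

# Hypothetical cutoff: verify against the archive before relying on it
CUTOFF = pd.Timestamp('2001-11-26')

df = pd.DataFrame({
    'air_date': pd.to_datetime(['1999-03-26', '2008-05-13']),
    'value': [100, 200],
})

# double the value of any question that aired on the old board
pre_doubling = df['air_date'] < CUTOFF
df.loc[pre_doubling, 'value'] = df.loc[pre_doubling, 'value'] * 2
```

The hybrid categories would need a bit more care, since only some of their rows follow the old structure.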

Another quirk of the dataset is that sometimes there are missing question/answer pairs in a given category. I assume this is because those questions were never revealed: if time expires before a clue is picked, it never makes it onto http://www.j-archive.com/. I’ll have to decide what I want to do with these; I may drop them from the initial database, or flag the category as incomplete so it can be filtered out when the query runs.
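One way to handle the filtering option without an explicit flag column would be a GROUP BY ... HAVING clause that only admits categories with a full set of five clues; a sketch against a toy question table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE question (category_id INTEGER, question TEXT)")

# category 1 ('FULL') has five clues, category 2 ('PARTIAL') only three
rows = [(1, f"clue {i}") for i in range(5)] + [(2, f"clue {i}") for i in range(3)]
conn.executemany("INSERT INTO question VALUES (?, ?)", rows)

# only categories with all five clues survive the HAVING filter
complete = conn.execute(
    "SELECT category_id FROM question "
    "GROUP BY category_id "
    "HAVING COUNT(*) = 5"
).fetchall()
```

This could slot into the sub-query that picks the 6 random categories, at the cost of a slightly heavier query.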

Summary & Next Steps

I’m pretty pleased with how far I managed to get in maybe 1 or 2 hours! The code samples above include a few firsts for me:

  • Defining a schema in SQL, including ensuring referential integrity
  • Using a nested SQL query
  • Using Flask to build an API

The surprising thing is, none of this was really that difficult! I think it’s a testament to the Flask framework for eliminating the need to write a whole load of boilerplate code and having a great tutorial, but it also gives me a sense of confidence that this stuff isn’t really that hard, I just need an engaging project as a reason to learn it.

The next steps I want to accomplish are to serve this up using a self-hosted WSGI server and then to write a website using the Jinja (https://jinja.palletsprojects.com/en/2.11.x/) templating language. Hopefully, the next article will include some screenshots of a real interactive Jeopardy board!

Jeopardy!

Hi everyone! I know I said I was going to start work on a weight tracking app next, but I came up with something a bit more fun! I was crawling through some of the fascinating datasets on Kaggle looking for some inspiration, and found a set of over 200,000 Jeopardy questions (https://www.kaggle.com/tunguz/200000-jeopardy-questions)!

My plan is to use this to learn a little more about front-end development, which is still my weakest area in the software development world. The aim is to build a page that generates random combinations of previous Jeopardy games, potentially allowing two people to play against each other in some kind of shared session.

First things first though, I need to get the raw data from Kaggle into a format that I can use in my front end. Python seems like a sensible tool to do that, and the Pandas library should make the process easy. The first thing I did was to write a simple script that would suck the data out of the CSV and dump it into a SQLite database. Thanks to the magic of Pandas and SQLAlchemy, the code is alarmingly simple…

import pandas as pd
from sqlalchemy import create_engine

df_questions = pd.read_csv('JEOPARDY_CSV.csv', parse_dates=[1])
engine = create_engine('sqlite:///questions.db', echo=False)
df_questions.to_sql('questions', con=engine, if_exists='replace')

And that’s it! Done! Time for a pint! Pandas automatically reads the CSV (including headers), parses the dates into real datetime objects, and then SQLAlchemy creates a SQLite database file and dumps the data into it. Five lines including import statements; even for Python, that’s impressive.

However, there are some problems with this. The first is that the column names in the CSV have spaces all over the place, which is not ideal since both SQL and Pandas make you jump through all sorts of hoops to address columns with spaces in the name. However, that’s an easy fix with a renaming dict, which turns the code above into something like this:

import pandas as pd
from sqlalchemy import create_engine

column_rename = {"Show Number": "show_identity",
                 " Air Date": "air_date",
                 " Round": "round_name",
                 " Category": "category",
                 " Value": "value",
                 " Question":"question",
                 " Answer": "answer"}

df_questions = pd.read_csv('JEOPARDY_CSV.csv', parse_dates=[1])
df_questions.rename(columns = column_rename, inplace=True)
engine = create_engine('sqlite:///questions.db', echo=False)
df_questions.to_sql('questions', con=engine, if_exists='replace')

OK, so I’m now up to 7 lines. Still not bad. However, another problem remains, and that is the normalization of the data (or lack thereof). I’ll talk through the pros and cons of normalizing a dataset in a moment, but first, here’s a look at the current state of the database:

show_identity | air_date   | round_name       | category                      | value | question                           | answer
5462          | 2008-05-13 | Double Jeopardy! | Who’s Your Mommy?             | $800  | King Solomon                       | …
5462          | 2008-05-13 | Double Jeopardy! | It’s all from the Greek to me | $800  | An institution for the care…       | …
5462          | 2008-05-13 | Double Jeopardy! | Dr. Drew                      | $1200 | People 18 to 25 are vulnerable to… | …
5462          | 2008-05-13 | Double Jeopardy! | Recent Films                  | $1200 | The attempted assassination…       | …
5462          | 2008-05-13 | Double Jeopardy! | Music Class                   | $2000 | This large woodwind…               | …
5462          | 2008-05-13 | Double Jeopardy! | Lend me you Iroquois          | $1200 | The Iroquois Nationals…            | …
5462          | 2008-05-13 | Double Jeopardy! | Who’s Your Mommy?             | $1200 | Invention (proverbially)           | …

This is a random sample of 7 rows from the dataset, and the lack of normalization is pretty obvious. Of the 7 columns shown here, only 2 contain unique values (question and answer), 2 contain repeated values, and 3 are identical in every row! In fact, of the 216,930 unique questions in this dataset, there are only 3,640 unique shows and 47,200 unique categories. So there’s certainly huge scope for normalization here!
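Those counts come straight out of pandas’ nunique(); a sketch on a toy frame (column names matching my renaming dict below):

```python
import pandas as pd

df = pd.DataFrame({
    'show_identity': [5462, 5462, 5462, 104],
    'category': ['Dr. Drew', 'Music Class', 'Dr. Drew', 'Epitaphs'],
    'question': ['q1', 'q2', 'q3', 'q4'],
})

# nunique() per column is a quick redundancy check: the lower the count
# relative to len(df), the more scope there is for normalization
uniques = df.nunique()
```

On the real dataset, `df_questions.nunique()` gives the 216,930 / 3,640 / 47,200 figures quoted above.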

I’ll go through the normalization steps I went through in a moment, but as a starting point, here’s the schema diagram for the current unnormalized table:

The data as pulled from Kaggle doesn’t even really conform to Unnormalized Form (https://en.wikipedia.org/wiki/Unnormalized_form), since it doesn’t have an explicitly defined primary key (though Pandas was kind enough to create one for me automatically when I loaded the CSV). It does, however, have no duplicate rows, so it’s halfway there.

Normalizing The Dataset

Typically, before I start writing any database code I want to design my schema to figure out how I’m going to represent the data. I design relational databases for a living, and I’ve developed a pretty simple approach to the initial coarse-grained normalization of a database schema.

The trick is to figure out what each table conceptualizes.

Throughout my career, I have typically been involved in designing and building databases to store engineering data. If I need to build a database to store mechanical testing data (for example), I’m probably going to need tables about the material being tested (both the specification of the material and the specific instance of the material), the machine that’s doing the testing, the results of the test, the calculations performed on those results, etc.

In this case, the key Jeopardy concepts I came up with were:

  • The Show (show number, air date)
  • The Rounds (Jeopardy!, Double Jeopardy!, Final Jeopardy!)
  • The Categories (History, Brass, Birds, “T”elevision)
  • Question & Answers

Everything associated with Jeopardy (except for the contestants and Aaaaalex Trebek!) fits into one of those buckets, and so that sounds like a pretty good way to go about designing the normalized database. To formalize things a bit, here it is in UML:

This feels like the appropriate way to construct the database since it relates the key concepts of a Jeopardy game in a way that makes conceptual sense, e.g. multiple categories are included in a single show, and each category contains multiple questions and answers. Each question and answer pair belongs to a single category, and has a specific ‘value’.

However, I want to be sure I’ve considered the most significant possible ways of normalizing the data, so I’ll now go through the first three normal forms to make sure the database either adheres to that form, or that I have a justification for not doing so.

1st Normal Form (1NF)

The First normal form simply states that the columns in a table should be atomic (i.e. indivisible). We were actually already there with the initial CSV, so we pass this first test straight off. Great!

2nd Normal Form (2NF)

The Second normal form gets a bit trickier. To pass this one, we need to ensure that every column that is not part of a candidate key depends on the entire candidate key, and not just a part of it. A common ‘smell’ associated with breaking 2NF is a lot of duplicated data.

Our original table violated 2NF all over the place. A candidate key for the original table was the question itself, but everything except the answer is really an attribute of something else (the show, the round, the category) rather than of the question. Pulling the question, answer, and value columns into a separate table helps with this, and doing the same for the category name makes all the tables compliant with 2NF.

3rd Normal Form (3NF)

The Third normal form states that every non-prime column is non-transitively dependent on every key in that table, i.e. there should not be a column whose values depend on something other than a key. Again, this is something that my schema satisfies due to the creation of separate tables for each entity.

For example, in the original dataset, the air date of the show is functionally dependent on the show_identity and not on the question candidate key. Therefore there’s nothing to stop the same episode identity from having multiple different air dates. This is a violation of 3NF.

The normalized schema creates a dedicated Show table with the show identity, and the air_date is dependent purely on this column.
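To make the 3NF point concrete, here’s a tiny sqlite3 sketch (a toy schema, not the project code) showing that the normalized layout makes the ‘one show, two air dates’ anomaly unrepresentable:

```python
import sqlite3

conn = sqlite3.connect(":memory:")

# air_date now lives on the show table, keyed by the show identity,
# so a show can only ever have one air date
conn.execute("CREATE TABLE show (identity INTEGER PRIMARY KEY, air_date TEXT NOT NULL)")
conn.execute("INSERT INTO show VALUES (5462, '2008-05-13')")

try:
    # a second, conflicting air date for the same show is rejected outright
    conn.execute("INSERT INTO show VALUES (5462, '2008-05-14')")
    anomaly_rejected = False
except sqlite3.IntegrityError:
    anomaly_rejected = True
```

In the flat table, nothing would have stopped those two contradictory rows from co-existing.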

Pros and Cons of Excessive Normalization

Before I go on to the code that performs the normalization, I want to address the question of whether too much normalization is a bad thing. The answer, in my experience, is that you have to determine that on a case-by-case basis, based on both the data you are storing and the way that data is being used. It’s pointless having a theoretically perfect database if you have to perform 30 joins that incur a huge performance penalty, and it’s even worse if the front-end tool you are using just can’t do the joins that you need!

With that being said, I always like to start with an over-normalized database since during the implementation phase of a project it’s generally easier to denormalize a database than it is to normalize a database. Also, you only realize what you need to denormalize once you actually try to use the database. So I think it’s always better to over-normalize early on in a project, and then denormalize as the front-end comes together and additional constraints become clear.

The Code

The code to normalize the database makes extensive use of copying subsets of DataFrames in Pandas, dropping columns, removing duplicates, and mapping values. Most of it was pretty straightforward, but one particular area that required some figuring out was generating the category table. This step required creating an entirely new identity that was dependent on the category name, the round, and the show (since category names can be re-used between shows). My original approach was to:

  1. Copy the questions DataFrame to a ‘category’ DataFrame
  2. Drop the columns I didn’t want (value, question, answer)
  3. Delete duplicate rows
  4. Re-generate the DataFrame index
  5. Find the category identity (index) based on each round, category, show tuple in the questions DataFrame

# Create a categories table
df_categories = pd.DataFrame(columns=['show_identity', 'round_name', 'name'])

df_categories['show_identity'] = df_questions['show_identity'].to_numpy()
df_categories['round_name'] = df_questions['round_name'].to_numpy()
df_categories['name'] = df_questions['category'].to_numpy()
df_categories.drop_duplicates(inplace=True, ignore_index=True)

def get_category_id(question):
    show_identity = question.show_identity
    round_name = question.round_name
    category = question.category

    category_id = df_categories[(df_categories.round_name == round_name) &
                                (df_categories.name == category) &
                                (df_categories.show_identity == show_identity)].index[0]
    return category_id

# Get the category ID for each question
# (the .swifter accessor needs `import swifter` at the top of the file)
category_ids = df_questions.swifter.apply(get_category_id, axis=1)

# Add the category identity to the questions table, and delete unneeded columns
df_questions.insert(loc=0, column='category_identity', value=category_ids)

This just felt wrong when I was writing it; doing 200,000 lookups across three columns in a DataFrame is never going to be efficient, and even with the swifter (https://github.com/jmcarpenter2/swifter) library, which parallelized the operation across my 12-core Ryzen, it still took over 20 minutes to complete.

The better approach was to create the category index before I duplicated the DataFrame, that way the cross-referencing was already done without having to resort to a whole load of lookups. The steps to do this were as follows:

  1. Create the category DataFrame and copy the relevant columns
  2. Create the composite key based on those relevant columns and create a map between those values and an int
  3. Map the composite key in the category DataFrame
  4. Map the composite key in the question DataFrame, and drop the columns that were moved to the category DataFrame

This cut the time taken to build the category table from around 20 minutes to about 10 seconds, an improvement of roughly 120×! You can see that code below, starting from the comment line # Create a categories table.

import pandas as pd
from sqlalchemy import create_engine
import numpy as np

column_rename = {"Show Number": "show_identity",
                 " Air Date": "air_date",
                 " Round": "round_name",
                 " Category": "category",
                 " Value": "value",
                 " Question":"question",
                 " Answer": "answer"}

df_questions = pd.read_csv('JEOPARDY_CSV.csv', parse_dates=[1])

df_questions.rename(columns = column_rename, inplace=True)

# Turn the value into an integer
df_questions['value'] = df_questions['value'].map(
    lambda a: int(a[1:].replace(',', '')) if a != 'None' else np.nan)

# Create a shows table, and delete corresponding columns in questions table
df_shows = pd.DataFrame(columns=['identity', 'air_date'])
df_shows['identity'] = df_questions['show_identity'].to_numpy()
df_shows['air_date'] = df_questions['air_date'].to_numpy()
df_shows.drop_duplicates(inplace=True, ignore_index=True)
df_shows.set_index('identity', inplace=True, append=False)
df_questions.drop('air_date', axis=1, inplace=True)

# Create a categories table
df_categories = pd.DataFrame()

df_categories['show_identity'] = df_questions['show_identity'].to_numpy()
df_categories['round_name'] = df_questions['round_name'].to_numpy()
df_categories['category'] = df_questions['category'].to_numpy()

# Define the composite key that uniquely identifies a category
category_composite_key = ['show_identity', 'round_name', 'category']
df_categories.drop_duplicates(inplace=True, ignore_index=True)

old_category_identities = df_categories[category_composite_key].apply(lambda row: '_'.join(row.values.astype(str)),
                                                                      axis=1)
new_category_identities = range(len(old_category_identities))
category_identity_mapping = dict(zip(old_category_identities, new_category_identities))


def map_category_identity(category):
    composite_key = '_'.join(category.values.astype(str))
    return category_identity_mapping[composite_key]


df_categories['identity'] = df_categories.apply(map_category_identity, axis=1)
df_categories.rename(columns={'category': 'name'}, inplace=True)
df_categories.index.rename('identity', inplace=True)
df_categories.set_index(keys=['identity'], append=False, inplace=True, drop=True)

# Perform the same mapping on the questions themselves, and drop the individual elements of the key
df_questions.insert(0, 'category_identity', value=df_questions[category_composite_key].apply(map_category_identity,
                                                                                             axis=1))
df_questions.drop(category_composite_key, axis=1, inplace=True)
df_questions.index.rename('identity', inplace=True)

# Create a rounds table
round_name_mapping = {'Jeopardy!': 0, 'Double Jeopardy!': 1, 'Final Jeopardy!': 2, 'Tiebreaker': 3}
df_rounds = pd.DataFrame.from_dict({'round_name': round_name_mapping.keys()})
df_rounds.index.rename('identity', inplace=True)

round_identity = df_categories.apply(lambda c: round_name_mapping[c.round_name], axis=1)
df_categories.insert(loc=0, column='round_identity', value=round_identity)
df_categories.drop('round_name', axis=1, inplace=True)

# Export to SQLite
engine = create_engine('sqlite:///questions.db', echo=False)
df_questions.to_sql('questions', con=engine, if_exists='replace')
df_shows.to_sql('shows', con=engine, if_exists='replace')
df_categories.to_sql('categories', con=engine, if_exists='replace')
df_rounds.to_sql('rounds', con=engine, if_exists='replace')


Referential Integrity

The code above successfully creates the different tables in SQLite with the right columns, which allow all the various joins to be performed to find out which questions appeared on a certain show (through the category table), or ultimately even to re-create the original CSV.

However, the SQLite database doesn’t have any constraints to ensure that new data added to the database satisfies the requirements for the joins described above. And unfortunately, since I’m using SQLite at the moment, there’s no simple way to apply constraints to the tables if you are creating them with DataFrame.to_sql() (https://www.thetopsites.net/article/53283866.shtml). Since this is still in the prototype stage it’s not really critical that I handle this now, but before going into production I’d probably change this to a MySQL or Postgres backend, and at that point I’d want to make sure my referential integrity is all squared away.
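One workaround, even on SQLite, is to create the tables with explicit SQL constraints first and then have to_sql() append into them rather than create them. A sketch with a toy shows table:

```python
import sqlite3

import pandas as pd

conn = sqlite3.connect(":memory:")

# Create the table (and its constraints) explicitly first...
conn.execute(
    "CREATE TABLE shows (identity INTEGER PRIMARY KEY, air_date TEXT NOT NULL)")

df_shows = pd.DataFrame({'identity': [5462, 5463],
                         'air_date': ['2008-05-13', '2008-05-14']})

# ...then tell to_sql to append into the existing table instead of creating it
df_shows.to_sql('shows', con=conn, if_exists='append', index=False)
```

After this, a duplicate show identity raises an IntegrityError instead of silently landing in the table, which is exactly the safety net the auto-generated schema lacks.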

Next Steps

As I mentioned at the top, my goal for this is to be able to create something in a front-end that maybe lets me play a randomly generated game of Jeopardy, potentially with other people in some kind of shared session over the web.

However, the first step in that is going to be building a REST API that can return a random Jeopardy ‘game’ when requested. I’ll probably implement the first pass in Flask, but I also want to give some other languages a try; I’ve done some very basic work with Spring Boot in Java, and I think it would make a lot of sense to use that here as well. Maybe have the backend in Spring Boot and the front-end in a Flask-based web app…

Anyway, watch this space and see what I come up with!

The Surveying Problem Part 6: Stack-based DFS

Intro

As I mentioned in my previous post, the final approach I want to try is a stack-based DFS implementation. This should certainly be faster than any of the graph library-based approaches, and given Python’s performance issues with recursive algorithms, it might be the fastest approach of all.

The Code

First things first, here’s the code:

""" A recursive implementation of the surveying problem
find_reservoirs: Determine the number and location of contiguous reservoirs in a grid
get_neighbors: Find the number of active neighbor elements for a given element
"""
from tools import get_neighbors

METHOD_NAME = "Stack Method (List)"


def find_reservoirs(this_grid):
    """ Iteratively determines how many wells are needed, making the assumption
    that only one well is needed per contiguous field

    this_grid: Set of locations to be checked for reservoirs
    """

    checked_elements = set()
    stack = list()
    reservoirs = []

    remaining_nodes = set(this_grid)  # copy so the caller's set isn't mutated

    while remaining_nodes:
        reservoir = set()
        stack.append(remaining_nodes.pop())

        while stack:
            location = stack.pop()

            if location in checked_elements:
                continue

            reservoir.add(location)
            checked_elements.add(location)

            stack.extend(get_neighbors(location, this_grid))

        reservoirs.append(reservoir)
        remaining_nodes -= reservoir

    return reservoirs

It’s pleasingly concise, especially considering it’s not calling any external libraries. It’s essentially a standard DFS, but since I’m working with a forest instead of a single tree, I wrap the main loop iterating through the stack (the tree) with another while loop that iterates through all remaining nodes (the forest). This outer loop ‘seeds’ the stack with an arbitrary item popped from the set, and then once the stack is empty I remove the entire tree from the forest with a set subtraction. Once I’ve processed the final tree, it is subtracted from the forest and the outer while loop ends.
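As a quick sanity check, here’s a compressed, self-contained version of the function above; the get_neighbors here is a stand-in for the one imported from tools, and the sample grid is my own:

```python
def get_neighbors(element, grid):
    """Return the occupied 8-neighbours of a grid element."""
    x, y = element
    return [(x + dx, y + dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1)
            if (dx, dy) != (0, 0) and (x + dx, y + dy) in grid]

def find_reservoirs(grid):
    """Stack-based DFS over the forest of contiguous reservoirs."""
    remaining, checked, reservoirs = set(grid), set(), []
    while remaining:
        reservoir, stack = set(), [remaining.pop()]   # seed with an arbitrary node
        while stack:
            location = stack.pop()
            if location in checked:
                continue
            reservoir.add(location)
            checked.add(location)
            stack.extend(get_neighbors(location, grid))
        reservoirs.append(reservoir)
        remaining -= reservoir                        # remove the whole tree
    return reservoirs

# Two clusters: three connected cells near the origin, one isolated cell
reservoirs = find_reservoirs({(0, 0), (0, 1), (1, 1), (5, 5)})
```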

One design decision I made at this stage was to use a list as my stack. Python does have a deque1https://docs.python.org/3.8/library/collections.html#collections.deque object that I could have used instead, which is optimized for O(1) appends and pops at both ends; a list is also amortized O(1) for appends and pops at the right end, but O(n) for inserts and pops at the front. Since I only ever touch the top of the stack, I don’t expect a huge difference, but I’ll look at that in the next article.
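For reference, a deque really is a drop-in replacement for a list used purely as a stack, so swapping them is a one-line change:

```python
from collections import deque

# append() pushes and pop() removes the most recently added item (LIFO),
# exactly matching the list-based stack in the code above.
stack = deque()
stack.append((0, 0))
stack.append((0, 1))
top = stack.pop()
```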

Results

As I said last time, I wanted to dive into matplotlib to generate some plots of the data I have been generating. So here’s a bar chart that shows the relative performance of all the different implementations for a range of grid sizes and probability thresholds.

The takeaway from this chart is that the stack method is the fastest implementation in all scenarios, although it’s very close between the stack and recursive approaches. This isn’t surprising: we already know that the time taken to figure out which sites are near which other sites is the dominant factor in all methods, so any performance improvement on the traversal side is going to be limited.

Next Steps

The chart above is a sneak preview of some of the investigations I have started doing into the relative performance of the different DFS algorithms based on grid size and probability threshold. Grid size is a pretty simple scalar, in that beyond a certain point the number of trees scales linearly with the number of available locations. However, the probability threshold is a bit more interesting since it interacts with the clustering algorithm to very rapidly generate very dense forests at still relatively low thresholds.

The next article will be the final one in the series. It will dive into more detail on some of the interesting mathematical features of this problem and come up with a final determination on the best solution, including looking into list vs deque implementations of the stack algorithm.

The Surveying Problem Part 5: Graph-Tool

Background

As I mentioned in my previous posts, I wanted to use a graph module in Python to implement a depth-first search (DFS) to solve the surveying problem. My first go was with NetworkX, which is a pure Python implementation that offers a lot of Graph theory-based functionality. However, because it’s pure Python and has the added overhead of importing more modules and running more functions, it’s a lot slower than the solution I implemented.

An alternative is the graph-tool package for Python, which offloads a lot of the heavy lifting to C++. According to the graph-tool website, this results in a much faster implementation compared to NetworkX1https://graph-tool.skewed.de/performance. However, it certainly wasn’t a simple swap-out from NetworkX to graph-tool; I couldn’t even get the library to compile! The folks in the AUR2https://aur.archlinux.org/packages/python-graph-tool/ do warn against parallelizing the build process because of memory consumption, but I was running out of RAM even on a single thread. Luckily, the maintainer of graph-tool maintains a docker image for graph-tool3https://hub.docker.com/r/tiagopeixoto/graph-tool, which gave me a great opportunity to learn docker!

Learning Docker

I have known about docker for a few years but assumed it was just some magic version of virtualization, and never had a reason to dig into it in any more detail than that. I initially tried to just grab the graph-tool image and throw it into docker, but I ran into all sorts of problems around adding other Python packages to the image and getting PyCharm to see Python within the image. This all ultimately stemmed from the fact that I fundamentally didn’t understand docker, specifically the differences between images and containers.

In years gone by I would have either just read the manual or the free tutorials, or just tried to hack through and figure it out myself. However, I’ve reached the point in my life where I realize that if I can spend a few dollars on something that will save me a few hours, then that’s money very well spent! I found a course on Udemy4https://www.udemy.com/course/docker-mastery that I managed to get for $20, but I would have spent $100 on it, and it’s been excellent at explaining the concepts in an easy-to-understand way. I can’t recommend it enough.

I’m not going to go into the details of what I did with docker, but briefly, I created my own image built on the docker image I linked to above with the additional Python packages I needed for my code. Luckily the original image was built on Arch Linux (which I am very familiar with), so it was pretty straightforward to add a RUN line using pacman to add these packages. Using the Python interpreter within the docker container was also straightforward thanks to PyCharm Professional, which I was able to snag for 6 months as part of a Python Humble Bundle earlier this year.
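To give a flavour of that image extension (the image tag and package names here are hypothetical, not my actual Dockerfile), it amounted to something like:

```dockerfile
# Build on the maintainer's graph-tool image (Arch Linux based)
FROM tiagopeixoto/graph-tool:latest

# Add the extra Python packages the code needs via pacman
RUN pacman -Syu --noconfirm python-pandas python-matplotlib
```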

The Code

A bit more preamble than usual for this post, so much so that the code almost seems like an afterthought! Anyway, here’s the code:

""" A graph-tools implementation of the surveying problem
find_reservoirs: Determine the number and location of contiguous reservoirs in a grid
"""

import graph_tool as gt
import graph_tool.topology as topology

from tools import get_neighbors

METHOD_NAME = "Graph-Tool Method"


def find_reservoirs(locations):
    """ Uses a graph approach to find how many wells are needed, making the assumption that
    only one well is needed per contiguous field

    locations: Set containing all locations with oil
    """
    locations = {location: idx for idx, location in enumerate(locations)}

    locations_graph = gt.Graph()
    locations_graph.set_directed(False)

    locations_graph.add_vertex(len(locations))
    locations_prop = locations_graph.new_vertex_property("object", locations)

    edge_list = []
    for location in locations:
        neighbor_coords = get_neighbors(location, locations)
        edge_list.extend([(locations[location], locations[neighbor])
                          for neighbor in neighbor_coords])

    locations_graph.add_edge_list(edge_list)

    components, _ = topology.label_components(locations_graph, directed=False)
    wells = dict()
    for vertex, label in enumerate(components.a):
        if label not in wells:
            wells[label] = []
        wells[label].append(locations_prop[vertex])

    return wells.values()

It’s logically identical to the NetworkX approach: first we create the vertices, then we create the edges between adjacent vertices, and then we find the connected sub-graphs in our grid.

The thing that struck me most about graph-tool is that it’s extremely unpythonic. I suppose this shouldn’t be surprising considering it’s really just a Python wrapper on top of C++ code, but that wrapper is very, very thin. There aren’t really any helper functions, so you have to do a lot of work yourself getting data into the correct format for graph-tool and post-processing its results. This means graph-tool code is about 50% longer than the equivalent NetworkX code. I suppose one can argue that if performance is so important to you that you have gone to the effort of implementing graph-tool (and all the compilation pain it brings), you’d want complete control over these wrapper functions. Still, it would be nice if that wrapper was just a little bit thicker.

Results

The next step is to compare performance between the NetworkX and graph-tool implementations. The results are shown in the table below. The biggest takeaway is that graph-tool starts out slower than NetworkX for very small grids, and only becomes more performant from 100×100 upward. Even then, it remains slower than the recursive implementation.

| Grid Size | Time (Recursive) | Time (NetworkX) | Time (graph-tool) |
|---|---|---|---|
| 10×10 | 5.205E-05 s | 1.686E-04 s (×3.24 slower) | 3.620E-04 s (×6.95 slower) |
| 100×100 | 2.856E-03 s | 1.007E-02 s (×3.53 slower) | 5.851E-03 s (×2.05 slower) |
| 1000×1000 | 4.628E-01 s | 2.260E+00 s (×4.88 slower) | 8.253E-01 s (×1.78 slower) |
| 5000×5000 | 1.827E+01 s | 9.253E+01 s (×5.06 slower) | 2.722E+01 s (×1.49 slower) |

I have a pretty good idea about why this is: the actual work being done by the graph library is small compared to the work done to build the graph. A good way to check is to split the times above into the time spent building the graph (which is approximately the same between both methods) and the time spent solving it.

| Grid Size | Graph Build Time (NetworkX) | Clustering Time (NetworkX) | Graph Build Time (graph-tool) | Clustering Time (graph-tool) |
|---|---|---|---|---|
| 10×10 | 1.259E-04 s (73%) | 4.376E-05 s (27%) | 2.628E-04 s (67%) | 1.266E-04 s (33%) |
| 100×100 | 1.103E-02 s (81%) | 2.564E-03 s (19%) | 7.290E-03 s (82%) | 1.574E-03 s (18%) |
| 1000×1000 | 1.785E+00 s (87%) | 2.626E-01 s (13%) | 6.128E-01 s (78%) | 1.688E-01 s (22%) |
| 5000×5000 | 7.293E+01 s (84%) | 1.397E+01 s (16%) | 2.105E+01 s (83%) | 4.224E+00 s (17%) |

The broad thing to take away from this table is that my hypothesis from above is correct: the clustering time is a small fraction of the total time taken to solve the problem, and so any time saved here is not going to have a significant impact on overall performance. If I was really interested in optimizing performance I’d be better off re-implementing some of the graph creation logic in something like NumPy, or ultimately just re-implementing the code in a different language altogether.
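As an illustration of what that NumPy direction might look like (my own sketch, not code from the project), the edge list can be built by broadcasting all eight neighbour offsets against every occupied coordinate at once, leaving only the membership test in a Python loop:

```python
import numpy as np

# The eight (dx, dy) offsets around a grid element
OFFSETS = np.array([(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1)
                    if (dx, dy) != (0, 0)])

def neighbor_edges(locations):
    """Build the graph's edge list in bulk instead of per-node loops."""
    coords = np.array(sorted(locations))                   # shape (n, 2)
    occupied = set(map(tuple, coords.tolist()))
    # All candidate neighbours of all points in one broadcast: (n, 8, 2)
    candidates = coords[:, None, :] + OFFSETS[None, :, :]
    return [(tuple(point), tuple(cand))
            for point, cands in zip(coords.tolist(), candidates.tolist())
            for cand in cands if tuple(cand) in occupied]

edges = neighbor_edges({(0, 0), (0, 1), (5, 5)})
```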

There are some additional nuances that might be interesting to look into in more detail. One is that the ratio of build time to clustering time trends pretty clearly towards build time becoming the larger fraction at larger grids, but the 1000×1000 size specifically seems to buck this trend. It might be to do with the specific amounts of memory I am using: maybe I reach a point where memory allocation starts to slow down and my machine starts swapping. I’m not sure though; add that to the list of things to look into.

Next Steps

I’d really like to dig into matplotlib to try plotting some of these results to see what they look like. I might update this post in a few weeks with a plot showing some of the data from the tables above. My go-to environment for crude plots is usually Excel, but I’m stuck in Linux working on this stuff, so I should just bite the bullet and learn matplotlib.

In terms of algorithms and implementation, one final thing I can think of to try to drive performance further is to re-implement the initial recursive solution as a stack-based DFS algorithm. As I have mentioned before, Python isn’t suited to recursive implementations because there’s such a huge overhead on function calls, so there’s a good chance I could see some performance improvement here. However, regardless of the implementation, I’m still going to be executing the get_neighbors code the same number of times, which I already know is where most of the time goes, so there’s likely a pretty hard limit on how much performance can be improved.

The Surveying Problem: XSLT Bonus!

Update

So I started looking at graph-tool, but unfortunately, I fell at the first hurdle; I couldn’t even get it to compile. It kept running out of memory at the same point; I suppose I should have gone for more than 16 GB when I upgraded my computer earlier this year.

The solution is to use a docker image that is helpfully provided by the graph-tool maintainer, but in order to do that, I need to learn docker. I’m making some great progress, but it might be another week or so before I can get my Python code running against containerized Python. So while I do that, here’s a post I wrote a few weeks ago but was saving for a rainy day…

Background

Shortly after the interview where I came up with the DFS solution to the surveying problem I had an amusing thought. The panel said I could use any language I wanted, so I thought to myself, what would the panel have thought if I decided to try and solve the problem in XSLT1https://en.wikipedia.org/wiki/XSLT?

Anyone who knows me knows I have a strange fondness for XSLT. That strange fondness goes back to my first job after my Ph.D., where XSLT was used extensively to transform XML documents. The majority of my colleagues hated that language since it has a trifecta of attributes against it:

  1. It’s functional, as opposed to imperative
  2. It’s very rarely used
  3. It is itself written in XML

I personally always enjoyed writing XSLT; I saw the need to learn a functional programming language as an interesting challenge, and once I got used to thinking in a functional way, the code came naturally. The fact that it was written in XML wasn’t too much of an issue thanks to a decent IDE, the only real issue was that with a (comparatively) small user community, it was often up to me to figure out the best way to do things.

Another important thing to know about XSLT, besides the fact that it’s a functional language, is that it’s a transformation language, i.e. it works by transforming an input document into an output document. An XSLT file can’t do anything by itself, it must be applied to something.

The Solution

Anyway, enough background. I’ll now go through the code I wrote to solve this problem in XSLT. As I mentioned above, an XSLT file must be applied to an input document, so here’s my (abridged) input document:

<grid>
  <el>
    <x>0</x>
    <y>16</y>
  </el>
  <el>
    <x>0</x>
    <y>17</y>
  </el>
  <el>
    <x>0</x>
    <y>19</y>
  </el>
  <el>
    <x>1</x>
    <y>3</y>
  </el>
  <el>
    <x>1</x>
    <y>4</y>
  </el>
</grid>

It’s a pretty simple document format and required some pretty trivial modifications to the build_grid() code from a previous article to spit out XML instead of a Python set object. Each <el> element represents a location in the set, and is defined by its <x> and <y> child elements.
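Those build_grid() modifications aren’t shown here, but a minimal standard-library version might look like this (the function name grid_to_xml is my own, not from the project):

```python
import xml.etree.ElementTree as ET

def grid_to_xml(grid):
    """Serialize a set of (x, y) tuples into the <grid><el>... format above."""
    root = ET.Element("grid")
    for x, y in sorted(grid):  # sorted, matching the stylesheet's ordering assumption
        el = ET.SubElement(root, "el")
        ET.SubElement(el, "x").text = str(x)
        ET.SubElement(el, "y").text = str(y)
    return ET.tostring(root, encoding="unicode")

xml_doc = grid_to_xml({(0, 16), (0, 17)})
```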

Next is the XSLT itself:

<?xml version="1.0" encoding="UTF-8"?>
<xsl:stylesheet xmlns:xsl="http://www.w3.org/1999/XSL/Transform"
    xmlns:xs="http://www.w3.org/2001/XMLSchema"
    xmlns:exsl="http://exslt.org/common"
    exclude-result-prefixes="xs exsl"
    version="1.0">
    
   <xsl:include href="firstPass.xslt" />
   <xsl:include href="secondPass.xslt" />
    
    <xsl:output method="xml" indent="yes"/>
    
    <xsl:template match="grid">
        <!-- Actually implement the logic -->
        <xsl:variable name="firstPass">
            <xsl:apply-templates select="el" mode="first" />
        </xsl:variable>
        
        <!-- Strip duplicate locations and empty reservoirs -->
        <xsl:variable name="reservoirs">
            <xsl:apply-templates select="exsl:node-set($firstPass)/reservoir" mode="secondPass"/>
        </xsl:variable>
                
        <!-- Generate summary results -->
        <xsl:element name="results">
            <xsl:element name="numberOfReservoirs">
                <xsl:value-of select="count(exsl:node-set($reservoirs)/reservoir)" />
            </xsl:element>
            <xsl:element name="reservoirs">
                <xsl:apply-templates select="exsl:node-set($reservoirs)/reservoir" mode="results" />
            </xsl:element>
        </xsl:element>
    </xsl:template>
    
    <!-- Results wrapper template -->
    <xsl:template match="reservoir" mode="results">
        <xsl:copy>
            <xsl:attribute name="size">
                <xsl:value-of select="count(location)" />
            </xsl:attribute>
            <xsl:copy-of select="location"/>
        </xsl:copy>
    </xsl:template>
    
</xsl:stylesheet>
<?xml version="1.0" encoding="UTF-8"?>
<xsl:stylesheet xmlns:xsl="http://www.w3.org/1999/XSL/Transform"
    xmlns:xs="http://www.w3.org/2001/XMLSchema"
    xmlns:exsl="http://exslt.org/common"
    exclude-result-prefixes="xs exsl"
    version="1.0">
    
    <!-- First time through we need to create the reservoir element -->
    <xsl:template match="el" mode="first">        
        <xsl:element name="reservoir">

                <xsl:element name="location">
                    <xsl:copy-of select="*" />
                </xsl:element>
                
                <xsl:variable name="this_x" select="number(x/text())"/>
                <xsl:variable name="this_y" select="number(y/text())"/>
                
                <!-- Register that we have visited this element before -->
                <xsl:variable name="visited">
                    <xsl:copy-of select="." />
                </xsl:variable>
                
                <!-- We rely on the grid being sorted in increasing x and y. -->
                <!-- The first time through, we will only find a neighbor with increasing or same x -->
                <xsl:apply-templates select="following::el[
                    (x = $this_x and y = $this_y + 1) or
                    (x = $this_x + 1 and y = $this_y - 1) or
                    (x = $this_x + 1 and y = $this_y) or
                    (x = $this_x + 1 and y = $this_y + 1)]" mode="recursive">
                    <xsl:with-param name="visited" select="$visited" />
                </xsl:apply-templates>
            
        </xsl:element>
    </xsl:template>
    
    <!-- Subsequent times through we don't, but we do need to check we haven't been here before -->
    <xsl:template match="el" mode="recursive">
        <xsl:param name="visited" />
        
        <xsl:variable name="currentElement" select="." />
        
        <!-- Check if we have been here before, which stops infinite recursion -->
        <xsl:if test="not(exsl:node-set($visited)/el[x = $currentElement/x][y = $currentElement/y])">
            
            <xsl:variable name="this_x" select="number(x/text())"/>
            <xsl:variable name="this_y" select="number(y/text())"/>
            
            <!-- Add the current location to the list of visited locations -->
            <xsl:variable name="newVisited">
                <xsl:copy-of select="exsl:node-set($visited)/el" />
                <xsl:copy-of select="$currentElement" />
            </xsl:variable>
            
            <xsl:element name="location">
                <xsl:copy-of select="*" />
            </xsl:element>
            
            <!-- Apply this template over all neighbor locations -->
            <!-- Here we might need to go 'up' and 'left', so we need to check negative offsets -->
            <xsl:apply-templates select="../el[
                (x = $this_x - 1 and y = $this_y - 1) or
                (x = $this_x - 1 and y = $this_y) or
                (x = $this_x - 1 and y = $this_y + 1) or
                (x = $this_x and y = $this_y - 1) or
                (x = $this_x and y = $this_y + 1) or
                (x = $this_x + 1 and y = $this_y - 1) or
                (x = $this_x + 1 and y = $this_y) or
                (x = $this_x + 1 and y = $this_y + 1)]" mode="recursive">
                <xsl:with-param name="visited" select="$newVisited" />
            </xsl:apply-templates>
        </xsl:if>
    </xsl:template>
    
</xsl:stylesheet>
<?xml version="1.0" encoding="UTF-8"?>
<xsl:stylesheet xmlns:xsl="http://www.w3.org/1999/XSL/Transform"
    xmlns:xs="http://www.w3.org/2001/XMLSchema"
    exclude-result-prefixes="xs"
    version="1.0">
    
    <!-- We need a separator here, otherwise we can't tell the difference between e.g. 1,20 and 12,0 -->
    <xsl:key name="locations-by-coords" match="location" use="concat(x, ',', y)" />
    
    <!-- Only match reservoirs that have at least one location after stripping duplicates -->
    <xsl:template match="reservoir[count(location[generate-id() = generate-id(key('locations-by-coords', concat(x, ',', y))[1])]) != 0]" mode="secondPass">
        <xsl:copy>
            <!-- Iterate over the unique locations in a reservoir -->
            <xsl:apply-templates select="location[generate-id() = generate-id(key('locations-by-coords', concat(x, ',', y))[1])]" mode="secondPass" />
        </xsl:copy>
    </xsl:template>
    
    <!-- Actually output the location -->
    <xsl:template match="location" mode="secondPass">
        <xsl:copy-of select="." />
    </xsl:template>
    
</xsl:stylesheet>

You can see one of the things I hate about XSLT, which is how wordy and bloated the source is! In case you didn’t notice, it’s in three different files just to avoid hundreds and hundreds of lines of scrolling.

The code is executed in two parts. The first is extremely similar to the Python DFS solution, although because XSLT is a functional language and all variables are immutable, there’s no global ‘visited’ set to tell whether a location has already been visited. There are some optimizations that can be made by ensuring the locations are visited in a well-defined order using the <xsl:sort /> element, but that doesn’t eliminate visiting the same location twice. Consider the following situation:

Even if the list of locations is sorted by increasing x and then increasing y, or even increasing distance from the origin, you still wouldn’t necessarily identify (3,0) as being a part of the well unless you go all the way down to (3,2), and then back up to (3,0). At this point, you may already have started a new well at (3,0) which is a duplicate of the first well. The global ‘visited’ set in the Python implementation solves this problem, but that’s not possible in a functional implementation.

The solution here is to add a second pass, which does two things. First, it removes duplicate nodes within the same well, caused when there are multiple possible ways of getting from one active location to another. Second, it removes all reservoirs that are a subset of another reservoir. Through some clever sorting, we can ensure that we will always have the full reservoir before any subsets, which means we can perform this filtering with a global filter on unique locations. The code implements Muenchian Grouping2https://en.wikipedia.org/wiki/XSLT/Muenchian_grouping to enable this, which is a neat way of removing duplicate nodes based on arbitrary but definable rules for what constitutes a duplicate within an element.
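In Python terms (my own analogy, not part of the XSLT), the locations-by-coords key behaves like de-duplicating on an (x, y) tuple while keeping the first occurrence:

```python
def dedupe_locations(locations):
    """Keep the first occurrence of each (x, y) pair, preserving order.

    This mirrors keying on concat(x, ',', y) in the stylesheet; the
    separator matters there because '1,20' and '12,0' must not collide.
    """
    seen = set()
    unique = []
    for loc in locations:
        key = (loc["x"], loc["y"])
        if key not in seen:
            seen.add(key)
            unique.append(loc)
    return unique

locs = [{"x": 0, "y": 16}, {"x": 0, "y": 16}, {"x": 1, "y": 17}]
```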

The final step is then a template that wraps the reservoir results with a summary of how many reservoirs were found in total and adds the size of each reservoir to each reservoir element. An excerpt of the output is given below.

<?xml version="1.0" encoding="UTF-8"?>
<results>
   <numberOfReservoirs>21</numberOfReservoirs>
   <reservoirs>
      <reservoir size="3">
         <location>
            <x>0</x>
            <y>16</y>
         </location>
         <location>
            <x>0</x>
            <y>17</y>
         </location>
         <location>
            <x>1</x>
            <y>17</y>
         </location>
      </reservoir>
      <reservoir size="2">
         <location>
            <x>0</x>
            <y>19</y>
         </location>
         <location>
            <x>1</x>
            <y>20</y>
         </location>
      </reservoir>

Thoughts

First things first, I know this is a terrible solution. It’s inefficient in that it requires multiple passes to generate the right output, and it fails completely at pretty modestly-sized grids (anything above 50×50 takes too long to complete). It might be possible to perform some improvements, but it’s not something I really care to do.

I realize that XSLT was never designed to perform this kind of work, but then I thought to myself: “how would you go about any kind of efficient BFS implementation without being able to rely on a global ‘visited’ tracker to avoid hitting the same location twice?” Is it even possible to implement an O(n) BFS algorithm in a functional language?

The answer seems to be yes! An article on Stack Overflow3https://stackoverflow.com/questions/30464163/functional-breadth-first-search has a solution in Haskell that seems to do exactly what’s required. I have no experience writing in Haskell, but it seems like an extremely interesting approach. If I get time, it’s something I’ll play around with to see how it works, and whether it makes sense for a problem like this!

Until then, back to docker…

The Surveying Problem Part 4: Graph Solution

As I mentioned in my previous post, the recursive solution to the Surveying Problem is pretty performant, but I wanted to see if a graph-based solution would be faster or easier to implement.

I started with the NetworkX package1https://networkx.github.io/, which I have used before in some other projects. In my experience, it’s reasonably quick, easy to install, and is released under a BSD license, which is important when you are writing in a commercial environment. Without further ado, here’s the code:

import networkx as nx

METHOD_NAME = "NetworkX Method"

def get_neighbors(node, graph):
    """ Returns a list of neighboring node coordinates

     node: A networkX node. The name is a 2-tuple representing the x and y coordinates of an element
     graph: A networkX graph of locations
     """
    x_coord, y_coord = node
    x_offsets = [-1, 0, 1]
    y_offsets = [-1, 0, 1]

    neighbors = list()
    for x_offset in x_offsets:
        for y_offset in y_offsets:
            if x_offset == 0 and y_offset == 0:
                continue

            coord = (x_coord + x_offset, y_coord + y_offset)
            if coord in graph:
                neighbors.append(coord)
    return neighbors


def find_reservoirs(grid):
    """ Uses a graph approach to find how many wells are needed, making the assumption that
    only one well is needed per contiguous field

    grid: Set containing all locations with oil
    """

    locations_graph = nx.Graph()
    locations_graph.add_nodes_from(grid)

    edges_to_create = set()
    for node in locations_graph:
        neighbors = get_neighbors(node, locations_graph)
        edges_to_create.update((node, neighbor) for neighbor in neighbors)

    locations_graph.add_edges_from(edges_to_create)

    connected_subgraphs = nx.connected_components(locations_graph)
    wells = [set(subgraph) for subgraph in connected_subgraphs]

    return wells

Pros: Simplicity

First, the pros of this approach. Firstly, the code is conceptually a lot simpler than the recursive approach:

  1. Create nodes for every location that has oil
  2. Connect nodes together if they are adjacent
  3. Call nx.connected_components() on our graph
    • This generates the set of connected subgraphs
    • i.e. the set of subgraphs where every node is connected to every other node in that subgraph
  4. List comprehension to get the correct output summary

Another pro is that a lot of the algorithmic design has been done by the NetworkX folks, as opposed to by me. As far as an interview question goes that’s obviously a bad thing since the point of the exercise was to test my problem-solving abilities, not to have an encyclopedic knowledge of which libraries are available. However, in a professional environment, I think it’s always better to use a library and let someone else do the thinking for you.

Cons: Performance

As for the cons, performance is significantly worse, almost by an order of magnitude. Here’s a summary of the comparison between the recursive solution and the NetworkX-based graph solution:

| Grid Size | Time (Recursive) | Time (NetworkX) | Time (Recursive / NetworkX) |
|---|---|---|---|
| 10×10 | 4.603E-05 s | 1.799E-04 s | 25.6% |
| 100×100 | 2.369E-03 s | 1.634E-02 s | 14.5% |
| 1000×1000 | 3.933E-01 s | 2.139E+00 s | 18.4% |
| 5000×5000 | 1.649E+01 s | 8.869E+01 s | 18.5% |

The obvious question to ask is, why? To answer this question, I looked at the source for the nx.connected_components() method.

Investigating the Slowdown

Behind the scenes, NetworkX is doing something very similar to what I implemented previously2https://github.com/networkx/networkx/blob/master/networkx/algorithms/components/connected.py. nx.connected_components() implements a Breadth-First Search3https://en.wikipedia.org/wiki/Breadth-first_search, whereas my recursive code implements a Depth-First Search4https://en.wikipedia.org/wiki/Depth-first_search. A BFS might make sense if we just wanted to establish the existence of every well, but since we need to find every well’s size and complete location, the choice between BFS and DFS really doesn’t make a difference; the optimal solution simply visits each location exactly once. The Wikipedia page on connected components5https://en.wikipedia.org/wiki/Connected_component_(graph_theory) states the following:

“It is straightforward to compute the components of a graph in linear time (in terms of the numbers of the vertices and edges of the graph) using either breadth-first search or depth-first search. In either case, a search that begins at some particular vertex v will find the entire component containing v (and no more) before returning. To find all the components of a graph, loop through its vertices, starting a new breadth-first or depth-first search whenever the loop reaches a vertex that has not already been included in a previously found component. Hopcroft & Tarjan (1973) describe essentially this algorithm, and state that at that point it was “well known”.”

https://en.wikipedia.org/wiki/Connected_component_(graph_theory)

In other words, my solution that I gave during my interview was literally the textbook solution to the problem!

But if the algorithms are effectively equivalent, then why the performance difference? I strongly suspect the reason is the additional overhead in the NetworkX approach, probably related to the number of additional function calls. Python is notoriously bad with function call latency6https://ilovesymposia.com/2015/12/10/the-cost-of-a-python-function-call/, and the poor performance of NetworkX is well documented7https://graph-tool.skewed.de/performance.
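That overhead is easy to demonstrate with a micro-benchmark (my own sketch, not from the NetworkX source): the same summation runs measurably slower when the arithmetic is pushed into a function call, even though both versions compute the same result.

```python
import timeit

def add(a, b):
    return a + b

def sum_inline():
    total = 0
    for i in range(1000):
        total += i              # no function call in the hot loop
    return total

def sum_with_calls():
    total = 0
    for i in range(1000):
        total = add(total, i)   # one extra function call per iteration
    return total

# timeit runs each callable many times and returns the total elapsed time;
# sum_with_calls consistently comes out slower on CPython
inline_t = timeit.timeit(sum_inline, number=1000)
calls_t = timeit.timeit(sum_with_calls, number=1000)
```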

Moving to a library that does more of the heavy lifting in a more performant language should enable some significant performance improvements. I’m going to take a look at graph-tool and see how that goes. I have been compiling it while writing this article and it’s still going, so it might be a while before I can start on the code and write the next article. Don’t hold your breath!

Next Steps

  • Implement the code in graph-tools
    • Pre-requisite: Compile graph-tools!
  • Try to get the NetworkX implementation to work on a grid of 10,000×10,000 elements. I did try this for the table above, but my system ran out of RAM. I might have to switch to a disk-based copy of the grid and build the graph as I go. This will come with a pretty significant performance hit, but it might be the only way to get the memory usage down

The Surveying Problem Part 3: The Recursive Solution

As I mentioned in my previous post, I have been working on some possible solutions to the following problem:

You are working for a company that is responsible for processing survey data for an Oil and Gas company. You are given data that summarizes whether oil was found at a certain grid reference. The assumption is that if oil is found in contiguous grid references (or grid elements), then they form a single oil reservoir. The objective is to write a program that will determine the number and locations of the reservoirs.

The question was posed to me in an interview context, so I only had access to the standard Python library. As a result, the solution I came up with used a recursive algorithm; here’s the code:

""" Generate a grid of elements with a random layout of elements that contain oil
build_grid: Build the grid with an optional random seed
get_neighbors: Find the number of active neighbor elements for a given element
"""

import random
import itertools

THRESHOLD = 0.85

def get_neighbors(this_element, grid):
    """ Returns a list of neighbor location objects

     this_element: A 2-tuple representing the x and y coordinates of an element
     grid: A set containing the coordinates of all active elements
     """
    x_coord, y_coord = this_element
    x_offsets = [-1, 0, 1]
    y_offsets = [-1, 0, 1]

    neighbors = list()
    for x_offset in x_offsets:
        for y_offset in y_offsets:
            if x_offset == 0 and y_offset == 0:
                continue

            coord = (x_coord + x_offset, y_coord + y_offset)
            if coord in grid:
                neighbors.append(coord)
    return neighbors


def build_grid(size, seed=None):
    """ Build a square grid of elements, where each element may or may not contain oil

     size: The number of elements along one edge of the grid
     seed: Random seed to be used to generate the grid
     """
    random.seed(seed)

    initial_grid = set()
    for location in itertools.product(range(0, size), repeat=2):
        if random.random() > THRESHOLD:
            initial_grid.add(location)

    grid = set()
    # Cluster the grid: an active element stays active if it is not isolated;
    # an inactive element becomes active if it has at least 4 active neighbors
    for location in itertools.product(range(0, size), repeat=2):
        state = location in initial_grid
        sites_nearby = get_neighbors(location, initial_grid)
        neighbor_count = len(sites_nearby)
        if (state and neighbor_count != 0) or neighbor_count >= 4:
            grid.add(location)

    return grid

It’s worth saying that whilst I came up with the broad algorithm during the interview, the version above contains a few improvements over the original, both to the structure of the code and the performance.

One example: the code in the previous post had the locations object as a dict(), where each key was a 2-tuple representing the x and y coordinates, and each value was either True or False. This was fine, but with large grids (1000×1000 and up) it used a not-insignificant amount of memory, so I switched to a list containing only the active 2-tuples.
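The code above only shows grid generation; for completeness, here is a minimal sketch of the recursive search itself, operating on the set of active coordinates that build_grid() returns. This is my own reconstruction, not the original interview code, and the names explore and find_reservoirs are mine:

```python
def get_neighbors(this_element, grid):
    """Return the active neighbors of an element (8-connectivity)."""
    x_coord, y_coord = this_element
    return [(x_coord + dx, y_coord + dy)
            for dx in (-1, 0, 1) for dy in (-1, 0, 1)
            if (dx, dy) != (0, 0) and (x_coord + dx, y_coord + dy) in grid]

def explore(element, grid, visited, reservoir):
    """Recursively add everything connected to `element` to `reservoir`."""
    visited.add(element)
    reservoir.add(element)
    for neighbor in get_neighbors(element, grid):
        if neighbor not in visited:
            explore(neighbor, grid, visited, reservoir)

def find_reservoirs(grid):
    """Return a list of sets, one per contiguous reservoir."""
    visited = set()
    reservoirs = []
    for element in grid:
        if element not in visited:
            reservoir = set()
            explore(element, grid, visited, reservoir)
            reservoirs.append(reservoir)
    return reservoirs

# A 3-element reservoir and an isolated single-element reservoir
grid = {(0, 0), (0, 1), (1, 1), (4, 4)}
print(len(find_reservoirs(grid)))  # → 2
```

Note that CPython’s default recursion limit (around 1000 frames) would be hit by a single sufficiently large reservoir, so a production version would need sys.setrecursionlimit or an explicit stack.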

Performance Optimization

However, imagine my horror when I realized that while I had reduced my RAM usage, I had slowed down the performance on a 100×100 grid by a factor of over 1000! I must admit I had to resort to the profiler in PyCharm, which pointed me firmly at my get_neighbors function.

The profiler wasn’t kind enough to tell me that it was specifically my coord in all_locations call causing the slowdown, but the get_neighbors function doesn’t do a whole lot, so I quickly realized this must be the cause. After a quick Google, I came across the Python TimeComplexity page1https://wiki.python.org/moin/TimeComplexity. As long as the keys stored in a dict have sensible hash functions (which my 2-tuples do), an in check against a dict is O(1) on average, compared to O(n) for a list.

I moved away from a dict because I had no values, so instead of switching to a list, I should have switched to a set. A set is basically a dict without values (in CPython it’s implemented with a very similar hash table), and is an extremely useful type, especially when you want to perform set-style operations. Here’s the call graph using a set instead of a list; as you can see, the bottleneck in get_neighbors is gone.

In case you are wondering, the other 60% is in _find_and_load, so not much we can do there. This was for a grid size of around 100×100 though, and the _find_and_load overhead decreases as expected as the grid size is increased.
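As an aside, the list-vs-set difference is easy to demonstrate with a quick micro-benchmark (a sketch; absolute numbers will vary by machine):

```python
import timeit

coords = [(x, y) for x in range(300) for y in range(300)]
as_list = list(coords)
as_set = set(coords)
probe = (299, 299)  # worst case for the list: the last element

# Membership test repeated 1000 times against each container
list_time = timeit.timeit(lambda: probe in as_list, number=1_000)
set_time = timeit.timeit(lambda: probe in as_set, number=1_000)
print(f"list: {list_time:.4f} s, set: {set_time:.6f} s")
```

With 90,000 stored coordinates, the list scan is several orders of magnitude slower than the hash lookup, which is exactly the slowdown the profiler surfaced.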

That’s pretty much it in terms of the recursive approach. Performance is not bad, here are some results for some example runs:

| Grid Size | Time | # of Reservoirs | Time / Reservoir | Time / Element | Reservoir Frequency |
| --- | --- | --- | --- | --- | --- |
| 10×10 | 1.528E-05 s | 2 | 7.64E-06 s | 1.528E-07 s | 0.02 |
| 100×100 | 3.417E-03 s | 306 | 1.117E-05 s | 3.417E-07 s | 0.0306 |
| 1000×1000 | 3.657E-01 s | 28,425 | 1.287E-05 s | 3.657E-07 s | 0.0284 |
| 10000×10000 | 8.298E+01 s | 2,841,360 | 2.920E-05 s | 8.298E-07 s | 0.0284 |

The Time/Element and Time/Reservoir results are pretty flat, which is what I would expect. The average number of reservoirs per element (the reservoir frequency) is also reasonably flat, trending towards 0.0284. It might be interesting to look into this a bit more, especially how it depends on how the grid itself was made and the clustering rules. #TODO

Memory Use

One definite negative to the code above is that even with the shift to storing only active elements in a set, it’s extremely memory intensive. The final run on a 10,000×10,000 grid peaked at around 4 GB of memory, which I think is down to two factors:

Firstly, I can’t use a yield-type idiom because I’m bouncing around looking at elements before and after the one I am processing. Also, during the clustering process, both the original and clustered grids have to be in memory at once, so I can examine the current state while I build the new state.

Secondly, native Python types are pretty bloated, memory-wise2https://stackoverflow.com/questions/1331471/in-memory-size-of-a-python-structure. An int is 24 bytes and the overhead for a tuple is 56 bytes, so each element (a 2-tuple of ints) needs 104 bytes. The original grid is 15% active, which means a grid of 10,000×10,000 elements requires around 1.5 GB of RAM. Assuming the clustered grid is around 10% active, that’s 2.5 GB of memory to store both the original and clustered grids.
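These per-object numbers can be checked with sys.getsizeof (the exact values vary between Python versions and builds, so treat the arithmetic as approximate):

```python
import sys

point = (123456, 654321)  # a coordinate built from non-cached ints
# Shallow tuple size plus the two int objects it references
per_element = sys.getsizeof(point) + sum(sys.getsizeof(c) for c in point)
print(f"bytes per stored element: ~{per_element}")

# Scale up to a 10,000 x 10,000 grid at 15% active
grid_elements = 10_000 * 10_000
estimate_gb = 0.15 * grid_elements * per_element / 1e9
print(f"estimated RAM for the initial grid: ~{estimate_gb:.1f} GB")
```

Note that sys.getsizeof only reports shallow sizes, so the tuple and its int members have to be counted separately, and set/dict hash-table overhead comes on top of this.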

Next Steps

So, there are pros and cons to this approach, and it’s definitely worth looking at more. However, my plan is to dig into the graph-based approach to see how the code complexity and performance change.

The Surveying Problem Part 2: Building a Grid

If I’m going to solve this problem, the first thing I need is a grid. In my interview, the grid itself didn’t actually exist; it was all hypothetical: “assuming a grid in a format that you can dictate, write an algorithm that would solve this problem”. To actually solve the problem, I need a grid to work with!

Building a Grid

I assumed a square grid throughout, so only one parameter is needed to define the size. I then wrote a pretty basic function that returns a grid in the form of a dict, where each key is a 2-tuple representing the x and y coordinates of an element, and each value is either True or False, depending on whether there is oil there or not.

import random
import itertools

THRESHOLD = 0.85

def build_grid(size, seed=None):
    """ Build a square grid of elements, where each element may or may not contain oil

     size: The number of elements along one edge of the grid
     seed: Random seed to be used to generate the grid
     """
    random.seed(seed)

    grid = dict()
    for x_coord, y_coord in itertools.product(range(0, size), repeat=2):
        grid[(x_coord, y_coord)] = random.random() > THRESHOLD

    return grid

I somewhat arbitrarily picked 0.85 as the threshold for determining whether an element contains oil, which seemed to give reasonable results. However, the grids generated by this approach often contained a lot of isolated ‘oil’ (or active) elements. I was keen to have some degree of clustering, so that groups of adjacent active elements would come back as a single reservoir. Here’s an example grid generated with the code above:

Clustering the Grid

There are a few active elements that are adjacent to one or more other active elements, but a good number are isolated. I wanted something a bit more representative of ‘reality’, or at least my idea of reality (I’m not a geologist!), where the active elements are always clustered together. To achieve this, I modified the build_grid() code to perform an additional round of processing that clusters the grid:

""" Generate a grid of elements with a random layout of elements that contain oil
build_grid: Build the grid with an optional random seed
get_neighbors: Find the number of active neighbor elements for a given element
"""

import random
import itertools

THRESHOLD = 0.85

def get_neighbors(this_element, grid):
    """ Returns a list of neighbor location objects

     this_element: A 2-tuple representing the x and y coordinates of an element
     grid: A dictionary containing all elements and their state
     """
    x_coord, y_coord = this_element
    x_offsets = [-1, 0, 1]
    y_offsets = [-1, 0, 1]

    neighbors = list()
    for x_offset in x_offsets:
        for y_offset in y_offsets:
            if x_offset == 0 and y_offset == 0:
                continue

            coord = (x_coord + x_offset, y_coord + y_offset)
            if grid.get(coord, False):
                neighbors.append(coord)
    return neighbors


def build_grid(size, seed=None):
    """ Build a square grid of elements, where each element may or may not contain oil

     size: The number of elements along one edge of the grid
     seed: Random seed to be used to generate the grid
     """
    random.seed(seed)

    initial_grid = dict()
    for x_coord, y_coord in itertools.product(range(0, size), repeat=2):
        initial_grid[(x_coord, y_coord)] = random.random() > THRESHOLD

    grid = set()
    # Cluster the grid: an active element stays active if it is not isolated;
    # an inactive element becomes active if it has at least 4 active neighbors
    for location, state in initial_grid.items():
        sites_nearby = get_neighbors(location, initial_grid)
        neighbor_count = len(sites_nearby)
        if (state and neighbor_count != 0) or neighbor_count >= 4:
            grid.add(location)

    return grid

This method implements something inspired by Conway’s Game of Life1https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life, by getting the number of adjacent active elements for each element in turn and applying the following rules:

  • If an element is active and has at least 1 active neighbor, it remains active
  • If an element is inactive and has at least 4 active neighbors, it becomes active
  • Any other element is inactive

This tends to produce a more clustered grid, which seems more representative of reality:

Next Steps

And that’s it! I now have code that procedurally generates grids to use as inputs to my solution. A couple of notes, though, before I move on:

  1. I used two different seeds for the two figures in this article, so the active elements don’t line up. Sorry if that causes any confusion!
  2. The code for drawing the output will be covered in a future article

Next step, solve the problem!