There’s no walking the show floor at RSA Conference or Black Hat these days without vendors bombarding you with messages like “machine learning” and “artificial intelligence”. But few of them will get around to explaining what it is, how it works and why you should care.
At Sophos, we’ve made big investments in data science and machine learning, including acquiring machine learning company Invincea and establishing a team of leading data scientists focused on infusing machine learning into the core of our products.
Our approach to machine learning has always been based on science, transparency and validation.
Instead of describing machine learning as pixie dust to be sprinkled on products, we’ll use this forum to describe the nuts, bolts and challenges of machine learning, and how we approach them.
It’s just mathematics, right?
One of the hardest things about evaluating machine learning products is finding out what’s under the hood, and why it works the way it does.
To help, we’ve come up with a list of five questions that cut to the heart of how well a particular approach performs, regardless of which algorithm it uses.
Q1. That’s an impressive detection rate, but what’s the false positive rate when you turn up detection that high?
Just quoting detection rates for machine learning algorithms is not enough.
After all, you can trivially achieve a 100% detection rate by simply convicting as malicious every file that you scan. But your false positive rate – that’s when you wrongly prevent legitimate files from being used – will be 100% as well. You won’t be able to use your computer at all.
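To make that concrete, here is a minimal sketch of how the two rates are worked out from a batch of verdicts. The `convict_everything` scanner and the ten sample files are hypothetical, made up purely for illustration.

```python
def rates(labels, verdicts):
    """Return (detection_rate, false_positive_rate).

    labels:   True for files that really are malicious, False for clean files.
    verdicts: True where the scanner convicted the file as malicious.
    """
    tp = sum(1 for y, v in zip(labels, verdicts) if y and v)      # malware caught
    fp = sum(1 for y, v in zip(labels, verdicts) if not y and v)  # clean files blocked
    positives = sum(1 for y in labels if y)
    negatives = len(labels) - positives
    return tp / positives, fp / negatives


def convict_everything(files):
    # Hypothetical "scanner" that blocks every file it is shown.
    return [True for _ in files]


# Hypothetical sample: 3 malicious files followed by 7 clean ones.
files = [f"file{i}" for i in range(10)]
labels = [True, True, True] + [False] * 7
verdicts = convict_everything(files)
detection, fpr = rates(labels, verdicts)
print(f"detection rate {detection:.0%}, false positive rate {fpr:.0%}")
# prints: detection rate 100%, false positive rate 100%
```

A perfect-looking detection rate on its own tells you nothing; only the pair of numbers does.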
In other words, the false positive rate for the detection algorithm is at least as important as its true detection rate. Ignoring the false positive rate means constantly chasing phantoms on the network or interrupting your users’ work.
In machine learning, this trade-off is shown as a graph called the receiver operating characteristic curve (ROC curve), which plots the true detection rate against the false positive rate as the detection threshold is varied.
If the vendor can’t, or won’t, show you an ROC curve, you can’t even begin to guess how low a detection rate you will have to accept – in other words, how much malware the product will let through – before its false positive rate becomes tolerable.
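As a rough illustration, the sketch below traces an ROC curve by sweeping a score threshold over a handful of files, then reads off the best detection rate available within a given false positive budget. The scores and labels are hypothetical; a real evaluation would use many thousands of files and far smaller false positive budgets.

```python
def roc_points(scores, labels):
    """Return (false_positive_rate, detection_rate) pairs, one per threshold.

    A file is convicted when its score is >= the threshold, so sweeping the
    threshold from high to low traces the ROC curve from (0, 0) to (1, 1).
    """
    positives = sum(labels)
    negatives = len(labels) - positives
    points = [(0.0, 0.0)]
    for threshold in sorted(set(scores), reverse=True):
        tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y)
        fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and not y)
        points.append((fp / negatives, tp / positives))
    return points


def detection_at_fpr(points, max_fpr):
    """Best detection rate achievable without exceeding max_fpr."""
    return max(tpr for fpr, tpr in points if fpr <= max_fpr)


# Hypothetical model scores (higher = more likely malicious) and true labels.
scores = [0.95, 0.90, 0.80, 0.70, 0.60, 0.40, 0.30, 0.20, 0.10, 0.05]
labels = [True, True, False, True, True, False, False, True, False, False]

points = roc_points(scores, labels)
print(detection_at_fpr(points, 0.0))  # 0.4: no false positives allowed
print(detection_at_fpr(points, 0.2))  # 0.8: one clean file in five may be blocked
```

The loop is quadratic and only meant to show the idea; in practice a library routine such as scikit-learn's `sklearn.metrics.roc_curve` does the same job efficiently.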
What to ask: Can I see your ROC curves, both now and from the past?
Q2. How often does your model need updating, and how much does your model’s accuracy drop off between updates?
The real power of machine learning is that if it is properly trained, it can very reliably detect threats that it hasn’t seen before. That makes it especially effective at blocking new threats proactively.
Additionally, a good machine learning model will not only perform well on today’s threats, but will also display a characteristic known as “slow aging”, meaning that it will continue picking off new threats for a long time without needing an update.
In short, a good machine learning model will retain an acceptable balance of detection rate versus false positive rate for months, rather than just for weeks or days.
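One way to check for slow aging is to freeze a model and its decision threshold at release, then score test sets collected in later months without updating anything. The sketch below assumes that setup; the monthly scores and labels are invented for illustration and do not describe any real product.

```python
def rates_at_threshold(scores, labels, threshold):
    """Detection rate and false positive rate when convicting scores >= threshold."""
    tp = sum(1 for s, y in zip(scores, labels) if y and s >= threshold)
    fp = sum(1 for s, y in zip(scores, labels) if not y and s >= threshold)
    positives = sum(labels)
    negatives = len(labels) - positives
    return tp / positives, fp / negatives


def aging_report(monthly_sets, threshold):
    """Apply one frozen model threshold to test sets collected after release.

    monthly_sets: list of (months_since_release, scores, labels), where the
    scores come from the model exactly as shipped, with no updates.
    """
    report = []
    for month, scores, labels in monthly_sets:
        detection, fpr = rates_at_threshold(scores, labels, threshold)
        report.append((month, detection, fpr))
    return report


# Hypothetical data: four malicious files then two clean files each month.
labels = [True, True, True, True, False, False]
monthly_sets = [
    (0, [0.9, 0.8, 0.7, 0.6, 0.2, 0.1], labels),
    (3, [0.9, 0.7, 0.6, 0.4, 0.3, 0.1], labels),
    (6, [0.8, 0.6, 0.4, 0.3, 0.55, 0.1], labels),
]
for month, detection, fpr in aging_report(monthly_sets, threshold=0.5):
    print(f"month {month}: detection {detection:.0%}, false positives {fpr:.0%}")
```

A slowly aging model is one whose numbers in a report like this stay close to the month-0 row for many months, rather than sliding as they do in this made-up example.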
What to ask: Can I see today’s ROC curve for the update you published six months ago?
Q3. Does your machine learning algorithm make decisions in real-time?
If scanning for malware takes longer than the malware takes to do its dirty work, you have detection but not prevention. In other words, you’ll only ever find out about an attack after the fact.
Some forms of machine learning are used to sift through data after an attack, trying to find the proverbial needle in a haystack. This is useful for figuring out what happened so you can prevent it happening again, but if you are in the business of stopping attacks before they succeed, you need an algorithm that makes its decision in milliseconds, not in seconds or minutes.
Generally speaking, a machine learning solution with a dataset that won’t fit on your endpoints will require a cloud connection to work at all, and will be both slow and unreliable. In fact, the dataset should be sufficiently compact that it can be held in memory, thus avoiding the need to keep reading detection data off disk as each file is scanned.
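To get a feel for what "in memory and in milliseconds" means, here is a minimal sketch that holds a hypothetical linear model as a compact array of weights and times each decision. The model width, the random weights and the sample files (represented as lists of active feature indices) are all assumptions for illustration; a real scanner would also have to time feature extraction from the file itself.

```python
import random
import statistics
import time
from array import array

N_FEATURES = 2 ** 16  # hypothetical model width


def make_model(seed=0):
    """Hypothetical linear model: one 32-bit float weight per feature, in memory."""
    rng = random.Random(seed)
    return array("f", (rng.uniform(-1, 1) for _ in range(N_FEATURES)))


def score(weights, active_features):
    """Sum the weights of the features present in a file (sparse dot product)."""
    return sum(weights[i] for i in active_features)


def latency_ms(weights, samples):
    """Time each decision and return (median, worst) in milliseconds."""
    timings = []
    for features in samples:
        start = time.perf_counter()
        score(weights, features)
        timings.append((time.perf_counter() - start) * 1000)
    return statistics.median(timings), max(timings)


weights = make_model()
print(f"model size: {len(weights) * weights.itemsize / 1024:.0f} KiB")

# Hypothetical "files": each one switches on 500 random features.
rng = random.Random(1)
samples = [rng.sample(range(N_FEATURES), 500) for _ in range(200)]
median, worst = latency_ms(weights, samples)
print(f"median {median:.3f} ms, worst {worst:.3f} ms")
```

Reporting the worst case alongside the median matters, because a scanner that is usually fast but occasionally stalls still lets some files run before a verdict arrives.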
What to ask: Is this real time? If so, how long does a decision take? What happens to performance and accuracy if the computer is offline?
Q4. What is your training set?
An algorithm’s performance in practice depends on the data on which it is trained.
The old adage “garbage in, garbage out” applies here: if the data is academic, old or doesn’t realistically represent real-world files, the algorithm is unlikely to perform reliably when faced with files outside the lab.
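One common way to test whether a training set will hold up outside the lab is a time-based split: train only on files first seen before a cutoff date, and test only on files seen after it. This is a general evaluation practice rather than a description of any particular vendor's pipeline, and the file names and dates below are hypothetical.

```python
from datetime import date


def temporal_split(samples, cutoff):
    """Split samples into train (first seen before cutoff) and test (on or after).

    samples: list of (first_seen_date, file_id, is_malicious) tuples.
    Testing only on files that appeared after the training data was frozen
    mimics what the model will actually face once it is deployed.
    """
    train = [s for s in samples if s[0] < cutoff]
    test = [s for s in samples if s[0] >= cutoff]
    return train, test


# Hypothetical samples with the date each file was first seen.
samples = [
    (date(2017, 1, 10), "a.exe", True),
    (date(2017, 2, 3), "b.dll", False),
    (date(2017, 3, 21), "c.exe", True),
    (date(2017, 5, 2), "d.exe", False),
    (date(2017, 6, 15), "e.exe", True),
]
train, test = temporal_split(samples, cutoff=date(2017, 4, 1))
print([s[1] for s in train])  # ['a.exe', 'b.dll', 'c.exe']
print([s[1] for s in test])   # ['d.exe', 'e.exe']
```

A random shuffle, by contrast, can put near-identical files from the same malware family on both sides of the split, which can make results look better than they will be in the field.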
What to ask: Where does the training data come from? What makes it realistic? How much data is there? How do you keep the training sets up to date?
Q5. How well does your machine learning system scale?
Given that new threats (and new clean files) come out all the time, you need to keep collecting new, representative data continuously, at internet scale.
Collecting ever-increasing amounts of relevant data is hard enough on its own, but you also need to be able to scale up the speed at which you can re-train and re-test your model, or you’ll end up with updates taking longer and longer to produce.
However, the dataset you extract to run your machine learning model needs to stay constant in size, even if your training sets grow exponentially. Otherwise your runtime performance will get steadily worse just to keep detection rates afloat.
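One well-known technique for keeping a model's size fixed while the training data grows is feature hashing (the "hashing trick"): every feature, however many distinct ones the training data contains, is mapped into one of a fixed number of buckets. The sketch below is an illustration of that general technique, not a description of how any specific product works; the bucket count and the feature strings are hypothetical.

```python
import hashlib

N_BUCKETS = 2 ** 20  # fixed model width, chosen up front (hypothetical value)


def hashed_features(tokens, n_buckets=N_BUCKETS):
    """Map arbitrary string features to indices in a fixed range of buckets.

    However many distinct tokens appear in the training data, each one lands
    in one of n_buckets slots, so a model built on these indices never grows.
    hashlib is used instead of Python's built-in hash(), which is randomized
    per process and so would not give stable indices between runs.
    """
    indices = set()
    for token in tokens:
        digest = hashlib.blake2b(token.encode("utf-8"), digest_size=8).digest()
        indices.add(int.from_bytes(digest, "little") % n_buckets)
    return sorted(indices)


# Hypothetical features extracted from one small file...
small = hashed_features(["import:CreateFileW", "section:.text"])
# ...and from a much larger set of training data.
large = hashed_features([f"token{i}" for i in range(100_000)])
print(len(small), max(large) < N_BUCKETS)
```

The trade-off is collisions: unrelated features can share a bucket, so a larger `N_BUCKETS` means fewer collisions but a larger runtime model.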
In other words, “scale” has several variables, including how quickly you can ramp up the size and relevance of your training set, whether you can keep on top of the time needed for training and testing as the training data increases, and whether you can keep the size of your runtime dataset down without sacrificing the detection rate (or increasing the false positive rate).
What to ask: Can I see historical statistics showing how well your training and runtime performance has scaled?
More than just mathematics
To summarise, your ideal machine learning approach will:
- Give high detection rates and low false positives on known and unknown attacks, with a published ROC curve.
- Be trained on a robust training set that is representative of real world threats.
- Continue to deliver high performance for months after each update.
- Provide real time performance (threat blocking) without consuming large amounts of memory.
- Scale reliably, without using more memory or losing performance, even as the training set increases.
Together these may seem like a big ask, but they are exactly what is needed for any production machine learning approach, and also what we are focused on building and releasing at Sophos.
Next time you talk to a company that claims to use machine learning, be sure to ask these questions.
You’ll know what to say when they respond “It’s just mathematics”.