How To Build a Fraud Detection Model in Java Using Deeplearning4J
In this tutorial, we’ll learn how to combine the power of neural networks and scalable data management to build a real-world fraud detection system. AI isn’t just for JavaScript and Python anymore. Using Deeplearning4J, we’ll train a neural network in Java, and with MongoDB, we’ll manage and store transaction data efficiently.
Given a whole host of data on customer transactions, we are going to teach our model how to spot the frauds. Whether it’s a suspicious amount, an odd location, or a peculiar time, a lot of variables are at play, and it’s important to get it right! We’ll preprocess our transaction data, train a neural network to classify it, and integrate MongoDB for scalable data storage. By the end, we’ll have a fully functional application capable of identifying fraudulent transactions.
We’ll create a system that:
- Loads and preprocesses transaction data from a CSV file.
- Trains a neural network to classify transactions as fraudulent or non-fraudulent.
- Stores and retrieves transaction data in MongoDB, ensuring scalability and persistence.
Java often flies under the radar in the AI world, but it has some undeniable strengths that make it a solid choice for building AI systems, especially when we need to go beyond experiments and into scalable production environments.
- Integration with enterprise systems: Java is a cornerstone in enterprise software. When we want to embed AI into existing systems, Java’s prevalence makes it a natural fit.
- Performance and scalability: With its multithreading and JVM optimizations, Java handles high-performance and distributed applications with elegance and grace (hyperbole)—perfect for real-world AI workloads.
By using Java for AI, we’re not just building something cool—we’re creating something robust, scalable, and ready to handle the demands of the real production world. Enough yapping about why it is cool, let’s actually start building something.
- Data preprocessing: How to clean and prepare transaction data for training
- Training the neural network: Using Deeplearning4J to train a classification model
- Integrating MongoDB: Managing transaction data in a scalable and efficient database
- Real-time interaction: Building an interactive CLI for testing our model and running predictions. This can just as easily be an API, or any other way you plan on interacting with your application.
To follow along with this tutorial, you should have the following:
- Get started with MongoDB Atlas for free! MongoDB offers a free-forever Atlas cluster with the M0 tier.
Deeplearning4J is a Java-based deep learning library designed for developers who want to build production-ready AI systems without leaving the JVM ecosystem. Unlike some of the more experimental tools in the AI world, Deeplearning4J is much more mature and focuses on real-world use cases, offering features that integrate with enterprise applications.
At the risk of sounding like an ad campaign for it, here’s what makes it stand out:
- Native Java support: It’s built for Java developers, so we can use familiar tools and workflows (like MongoDB) to create deep learning models.
- Scalability: Deeplearning4J supports distributed training right out of the box, making it ideal for large datasets and high-performance applications.
- Flexibility: Whether we’re working on classification, regression, or more complex architectures, the library provides the building blocks to customize and optimize our models.
- Integration: Deeplearning4J plays nicely with other Java tools, including Apache Spark for big data processing and Hadoop for distributed systems.
Deeplearning4J stands out by being about more than just experimenting with AI in Java—we’re building models that can be deployed directly into production environments, alongside the rest of our Java applications.
The dataset we’re working with is the Credit Card Fraud dataset, available on Kaggle. It is a widely recognized credit card fraud detection dataset containing anonymized transactional data, and it provides a solid foundation for training our machine learning model to detect fraudulent transactions.
This dataset has the following key characteristics:
- Size and composition:
- Total transactions: 284,807
- Fraudulent transactions: 492 (0.172% of the data)
- The dataset is highly imbalanced, as fraudulent transactions represent a very small percentage of the total. This reflects the real-world challenge where fraud is rare but critical to detect.
- Features:
- Time: The seconds elapsed between a transaction and the first transaction in the dataset.
- Amount: The monetary value of the transaction.
- V1 to V28: Principal components resulting from PCA (Principal Component Analysis), used to anonymize sensitive data.
- Class: The target variable, where `0` indicates a legitimate transaction and `1` indicates fraud.
The extreme class imbalance (fraudulent transactions make up only 0.172% of the dataset) can complicate machine learning:
- Models can become biased toward the majority class (`0`), leading to poor performance on the minority class (`1`).
For this tutorial, we won’t focus on addressing the imbalance. Instead, we’ll train and evaluate the model on the original dataset, treating it as a learning exercise. However, techniques like resampling, weighted loss functions, or cost-sensitive learning would be essential to deal with an imbalanced dataset.
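If you’re curious what one of those techniques looks like, here is a minimal, hypothetical sketch of random undersampling—trimming the majority class until the ratio between classes is manageable. We won’t use it in this tutorial, and the `undersample` helper and its `ratio` parameter are illustrative, not part of any library:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class Resampler {
    // Randomly undersamples the majority class so it outnumbers the
    // minority class by at most `ratio` to 1
    public static <T> List<T> undersample(List<T> majority, List<T> minority, int ratio) {
        List<T> shuffledMajority = new ArrayList<>(majority);
        Collections.shuffle(shuffledMajority);

        int keep = Math.min(shuffledMajority.size(), minority.size() * ratio);
        List<T> balanced = new ArrayList<>(minority);
        balanced.addAll(shuffledMajority.subList(0, keep));
        Collections.shuffle(balanced); // mix the classes back together
        return balanced;
    }
}
```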
To get the data ready for our model, we’ll follow these steps:
- Selecting features:
  - We’ll start with the Amount feature as our input.
  - The Time feature and PCA components (`V1` to `V28`) will be excluded for simplicity, but they can be explored later to enhance the model. I ran a test with these later, and you'll be able to see why I decided to exclude them for our simple implementation.
- Target variable:
  - The Class column serves as our target variable, with `0` for legitimate transactions and `1` for fraud.
- Normalizing data:
  - Neural networks perform better with consistent input scaling, so we’ll normalize the Amount column to a range of 0 to 1 (see the sketch after this list).
- Splitting the data:
  - We’ll divide the dataset into a training set (80%) and a test set (20%) to evaluate how well our model generalizes to unseen data.
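The min-max scaling mentioned in the normalization step boils down to a single formula. Here’s a minimal sketch; `normalizeAmount` is a hypothetical helper, and in practice, the min and max would be computed from the training data rather than hard-coded:

```java
// Maps an amount into the [0, 1] range using min-max scaling
static double normalizeAmount(double amount, double minAmount, double maxAmount) {
    return (amount - minAmount) / (maxAmount - minAmount);
}

// Example: if amounts span 0.00 to 25000.00, a 125.00 transaction becomes 0.005
double scaled = normalizeAmount(125.00, 0.00, 25000.00);
```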
By keeping the preprocessing straightforward, we can focus on understanding the mechanics of building and training a neural network. While we’re not addressing the class imbalance directly in this tutorial, it remains an absolutely critical consideration for real-world applications. I will reiterate this point a couple times throughout this tutorial, but that’s because it is probably the most influential factor on our results in this model. Let’s move on to the actual preprocessing code!
First, we need to add our dependencies to our POM.
```xml
<dependencies>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    <dependency>
        <groupId>org.mongodb</groupId>
        <artifactId>mongodb-driver-sync</artifactId>
        <version>5.2.0</version>
    </dependency>
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-simple</artifactId>
        <version>2.0.16</version>
    </dependency>
</dependencies>
```
Here, we are importing a few dependencies for Deeplearning4J, as well as the MongoDB Java driver. We also have SLF4J just for logging, and ND4J. ND4J provides convenience methods for creating arrays from Java float and double arrays.
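For example, creating an `INDArray` (ND4J’s n-dimensional array type) from a plain Java array is a one-liner:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// A 2x2 matrix built directly from a Java double[][]
INDArray matrix = Nd4j.create(new double[][]{{1.0, 2.0}, {3.0, 4.0}});
```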
Create a `MongoDBConnector` class. This will be how we establish our connection to our MongoDB database, where we will store our data. For the `URI`, add the connection string for your database. Change your database and collection name to whatever makes sense, but I'm using `fraudDetection` and `transactions`, respectively, for this example.

```java
package com.mongodb;

import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import org.bson.Document;

import java.util.concurrent.TimeUnit;

public class MongoDBConnector {
    private static final String URI = "YOUR-CONNECTION-STRING";
    private static final String DATABASE_NAME = "fraudDetection";
    private static final String COLLECTION_NAME = "transactions";

    private final MongoClient mongoClient;

    public MongoDBConnector() {
        // Generous socket timeouts help avoid network-related exceptions
        // while bulk-inserting large batches of documents
        MongoClientSettings settings = MongoClientSettings.builder()
                .applyConnectionString(new ConnectionString(URI))
                .applyToSocketSettings(builder ->
                        builder.connectTimeout(30, TimeUnit.SECONDS)
                                .readTimeout(30, TimeUnit.SECONDS))
                .build();

        mongoClient = MongoClients.create(settings);
    }

    public MongoCollection<Document> getCollection() {
        return mongoClient.getDatabase(DATABASE_NAME).getCollection(COLLECTION_NAME);
    }
}
```
We are also applying specific timeout settings to our `MongoClient`. Since we’re uploading a large amount of data to our database in large chunks, this helps us work around any timeout exceptions caused by networking issues.

To work with our transaction data, we will create a `Transaction` class. We are going to simplify it for this demo and just use the amount and whether the transaction has been marked as fraudulent or not. This will significantly affect the reliability of our model, so feel free to add more features for classifying as you want. AI is a complex world, and training models is a real skill. Different features will be more reliable for making fraudulent predictions. Think of large transactions on the far side of the globe, or during odd hours for that particular user.
```java
package com.mongodb;

import org.bson.Document;

public class Transaction {
    private double amount;
    private boolean isFraudulent;

    public Transaction(double amount, boolean isFraudulent) {
        this.amount = amount;
        this.isFraudulent = isFraudulent;
    }

    public double getAmount() { return amount; }
    public boolean isFraudulent() { return isFraudulent; }

    // Convert the transaction into a MongoDB-friendly document
    public Document toDocument() {
        Document doc = new Document();
        doc.append("amount", amount);
        doc.append("isFraudulent", isFraudulent);
        return doc;
    }
}
```
This will provide a simple structure for our data and allow us to convert it into a MongoDB-friendly format with the `toDocument` method.

Next, we need a `TransactionRepository` class. This will be where we encapsulate the operations for saving and fetching our transactions from the database. This abstraction lets us keep our data access logic organized and reusable.

```java
package com.mongodb;

import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.BulkWriteOptions;
import com.mongodb.client.model.WriteModel;
import org.bson.Document;

import java.util.ArrayList;
import java.util.List;

public class TransactionRepository {
    private final MongoCollection<Document> collection;

    public TransactionRepository(MongoDBConnector connector) {
        this.collection = connector.getCollection();
    }

    public void bulkSaveTransactions(List<WriteModel<Document>> transactions) {
        if (!transactions.isEmpty()) {
            try {
                BulkWriteOptions options = new BulkWriteOptions().ordered(true);
                collection.bulkWrite(transactions, options); // Perform bulk write operation
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    public List<Transaction> getAllTransactions() {
        List<Transaction> transactions = new ArrayList<>();
        for (Document doc : collection.find()) {
            double amount = doc.getDouble("amount");
            boolean isFraudulent = doc.getBoolean("isFraudulent");
            transactions.add(new Transaction(amount, isFraudulent));
        }
        return transactions;
    }
}
```
This repository pattern makes it easy to add more database operations, such as filtering our transactions or performing more advanced queries.
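For instance, a hypothetical `getFraudulentTransactions` method could use the driver’s `Filters` helper to fetch only flagged transactions—a small sketch of how the repository might grow:

```java
import com.mongodb.client.model.Filters;

// Hypothetical addition to TransactionRepository: fetch only flagged transactions
public List<Transaction> getFraudulentTransactions() {
    List<Transaction> frauds = new ArrayList<>();
    for (Document doc : collection.find(Filters.eq("isFraudulent", true))) {
        frauds.add(new Transaction(doc.getDouble("amount"), doc.getBoolean("isFraudulent")));
    }
    return frauds;
}
```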
Effective preprocessing of the data is an essential step in machine learning. Here, we are going to load the data from the `creditcard.csv` file, which should be placed in the `src/main/resources` directory. We’ll read in the fields we want to focus on in our model training and format them for our `Transaction` model.

```java
package com.mongodb;

import com.mongodb.client.model.WriteModel;
import com.mongodb.client.model.InsertOneModel;
import org.bson.Document;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class DataPreprocessor {

    TransactionRepository transactionRepository;

    private static final int BATCH_SIZE = 500; // Define batch size for bulk writes
    private static final int DOCUMENT_LIMIT = 250000; // Maximum documents to insert

    private int documentCount = 0; // Counter for total inserted documents

    public DataPreprocessor(TransactionRepository transactionRepository) {
        this.transactionRepository = transactionRepository;
    }

    public void loadData(String filePath) throws IOException {
        try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
            // Skip the header by reading the first line
            reader.readLine();

            List<String> batch = new ArrayList<>();
            String line;

            // Read the file line-by-line
            while ((line = reader.readLine()) != null) {
                if (documentCount >= DOCUMENT_LIMIT) {
                    System.out.println("Reached the document limit of " + DOCUMENT_LIMIT + ". Stopping data load.");
                    break; // Stop processing when the limit is reached
                }

                batch.add(line);
                documentCount++;

                // When batch size is reached, process it
                if (batch.size() == BATCH_SIZE) {
                    processBatch(batch);
                    batch.clear(); // Clear the batch for the next set of lines
                }
            }

            // Process any remaining lines
            if (!batch.isEmpty()) {
                processBatch(batch);
            }
        }
    }

    private void processBatch(List<String> batch) {
        List<WriteModel<Document>> bulkOperations = batch.stream()
                .map(line -> {
                    String[] fields = line.split(",");
                    double amount = Double.parseDouble(fields[29]); // Adjust index as needed
                    boolean isFraudulent = "1".equals(fields[30]);
                    Transaction transaction = new Transaction(amount, isFraudulent);
                    return new InsertOneModel<>(transaction.toDocument());
                })
                .collect(Collectors.toList());
        transactionRepository.bulkSaveTransactions(bulkOperations);
    }
}
```
To handle large datasets efficiently, we are using MongoDB’s bulkWrite operation. This reduces the number of network round trips from our application to our MongoDB instance, which improves the performance of our application. (Setting `ordered(false)` on the `BulkWriteOptions` would let MongoDB continue past individual failures instead of stopping at the first error, which can be faster still.)
We are also limiting the number of documents added to the database to 250,000. This is because we are using the MongoDB M0 tier cluster, which is limited to 512MB of storage, and we do not want to exceed this.
To detect the fraudulent transactions, we'll use Deeplearning4J to build our neural network. We'll do this in a `FraudDetectionModel` class.

- Input layer (implicit): Every neural network begins with an input layer that receives the data. In Deeplearning4J, this is handled automatically by the first layer. The `nIn(NUM_INPUT_FEATURES)` in the first `DenseLayer` specifies how many input features the model will take. For now, we’re using only the `Amount` feature, but you can easily expand this.
- Dense layer (hidden layer): This is the core of our model where learning happens. It consists of 10 neurons and uses the ReLU (Rectified Linear Unit) activation function, which introduces non-linearity to help the model learn complex patterns in the data.
- Output layer: The output layer predicts whether a transaction is fraudulent or not. It has two neurons—one for each class (fraudulent or legitimate)—and uses the Softmax activation function to generate probabilities for each class. The class with the highest probability becomes the model’s prediction.
- Loss function and optimizer:
  - The model uses the Negative Log Likelihood loss function, ideal for classification problems.
  - It’s optimized using the Adam optimizer with a learning rate of `0.001`, which adapts the learning rate throughout training for better convergence.
  - Xavier weight initialization is used to keep the starting weights balanced, preventing issues like vanishing or exploding gradients.
- Training feedback: We added a `ScoreIterationListener` that outputs the model’s training progress every 10 iterations, giving us insight into how well the model is learning.
If none of what I just said makes any sense, don't worry—it didn't for me when I started either. I've shown you how to implement it below, so you can see it in action. AI is a large field, and a lot of brilliant people have done amazing work to make it more accessible. Check out MongoDB's Developer Center to learn more about AI.
```java
package com.mongodb;

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.Collections;
import java.util.List;

public class FraudDetectionModel {
    private MultiLayerNetwork model;

    private static final int NUM_INPUT_FEATURES = 1;
    private static final int NUM_CLASSES = 2;

    public FraudDetectionModel() {
        initializeModel();
    }

    private void initializeModel() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .weightInit(WeightInit.XAVIER)
                .updater(new Adam(0.001))
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(NUM_INPUT_FEATURES)
                        .nOut(10)
                        .activation(Activation.RELU)
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(10)
                        .nOut(NUM_CLASSES)
                        .activation(Activation.SOFTMAX)
                        .build())
                .build();

        model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10));
    }

}
```
So what does all this mean? Well, let’s explore the components of this class and what each part does.
```java
private MultiLayerNetwork model;

private static final int NUM_INPUT_FEATURES = 1;
private static final int NUM_CLASSES = 2;
```
- `model`: The neural network instance that will be trained
- `NUM_INPUT_FEATURES`: The number of input features—currently set to 1 because only the transaction amount is used
- `NUM_CLASSES`: Set to 2, representing two possible outcomes: legitimate (0) or fraudulent (1)
```java
private void initializeModel() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam(0.001))
            .list()
```
- `NeuralNetConfiguration.Builder()`: Used to define the network architecture.
- `seed(123)`: Sets a random seed for reproducibility.
- `weightInit(WeightInit.XAVIER)`: Initializes the network weights using Xavier initialization, which balances the variance of weights, helping the network train efficiently.
- `updater(new Adam(0.001))`: Uses the Adam optimizer with a learning rate of 0.001 for adaptive learning
```java
.layer(0, new DenseLayer.Builder()
        .nIn(NUM_INPUT_FEATURES)
        .nOut(10)
        .activation(Activation.RELU)
        .build())
```
- Layer 0: The first hidden layer
- `nIn(NUM_INPUT_FEATURES)`: Number of inputs, currently 1 (transaction amount)
- `nOut(10)`: Number of neurons in this layer (10 neurons)
- `Activation.RELU`: ReLU activation introduces non-linearity, allowing the model to learn complex relationships
```java
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .nIn(10)
        .nOut(NUM_CLASSES)
        .activation(Activation.SOFTMAX)
        .build())
```
- Layer 1: The output layer
- `LossFunction.NEGATIVELOGLIKELIHOOD`: A loss function suitable for classification tasks
- `nIn(10)`: Takes the 10 outputs from the hidden layer
- `nOut(NUM_CLASSES)`: Outputs two values representing the probabilities of each class (fraudulent or legitimate)
- `Activation.SOFTMAX`: Converts outputs into class probabilities that sum up to 1.
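To see what Softmax actually does, here’s the math by hand for a pair of hypothetical raw outputs (logits) from the two output neurons:

```java
// Softmax for two hypothetical raw outputs (logits)
double[] logits = {2.0, 0.5};
double expSum = Math.exp(logits[0]) + Math.exp(logits[1]);
double pLegitimate = Math.exp(logits[0]) / expSum; // ≈ 0.82
double pFraud      = Math.exp(logits[1]) / expSum; // ≈ 0.18 (the two probabilities sum to 1)
```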
```java
model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(10));
```
- `model = new MultiLayerNetwork(conf);`: Creates the neural network using the configuration
- `model.init();`: Initializes the network’s parameters
- `setListeners(new ScoreIterationListener(10));`: Logs the model's score (error) every 10 iterations to track training progress
Next, we'll add the method `prepareTrainingData` to our class. The network expects data to be in the form of features (inputs) and labels (targets).

```java
private DataSet prepareTrainingData(List<Transaction> transactions) {
    int numTransactions = transactions.size();
    INDArray features = Nd4j.create(numTransactions, NUM_INPUT_FEATURES);
    INDArray labels = Nd4j.create(numTransactions, NUM_CLASSES);

    for (int i = 0; i < numTransactions; i++) {
        Transaction transaction = transactions.get(i);

        // Use transaction amount as the feature
        features.putScalar(new int[]{i, 0}, transaction.getAmount());

        // One-hot encoding for labels
        if (transaction.isFraudulent()) {
            labels.putScalar(new int[]{i, 1}, 1.0);
            labels.putScalar(new int[]{i, 0}, 0.0);
        } else {
            labels.putScalar(new int[]{i, 0}, 1.0);
            labels.putScalar(new int[]{i, 1}, 0.0);
        }
    }

    return new DataSet(features, labels);
}
```
We train the model with a simple loop over multiple epochs. An epoch is a machine learning term that refers to one complete pass of the entire dataset through the learning algorithm.
```java
public void trainModel(List<Transaction> transactions) {
    // Shuffle the data to ensure random distribution
    Collections.shuffle(transactions);

    // Prepare training data
    DataSet dataSet = prepareTrainingData(transactions);

    // Train the model
    for (int epoch = 0; epoch < 100; epoch++) {
        model.fit(dataSet);
    }

    System.out.println("Model trained successfully.");
}
```
After the training, we'll add a method `evaluateModel` to measure the model's performance. This is crucial in AI model training, in order to refine our methodology for how we want to implement the neural network.

```java
public void evaluateModel(List<Transaction> transactions) {
    // Take the last 20% of the provided data as the evaluation set
    int trainSize = (int) (transactions.size() * 0.8);
    List<Transaction> testSet = transactions.subList(trainSize, transactions.size());

    // Prepare test data
    DataSet testData = prepareTrainingData(testSet);

    // Perform evaluation
    Evaluation evaluation = new Evaluation(NUM_CLASSES);
    INDArray predicted = model.output(testData.getFeatures());
    evaluation.eval(testData.getLabels(), predicted);

    // Print evaluation statistics
    System.out.println(evaluation.stats());
}
```
We'll see this in our main application later, and we'll learn how to interpret our results.
Lastly, we'll add a method to predict fraud for a singular transaction.
```java
public boolean predictFraud(Transaction transaction) {
    // Convert the transaction to an INDArray
    INDArray input = Nd4j.create(new double[][]{{transaction.getAmount()}});

    // Perform prediction
    INDArray output = model.output(input);

    // Interpret the output
    // Index 1 corresponds to the fraud class (assuming one-hot encoding)
    return output.getDouble(0, 1) > 0.5;
}
```
We can use this if we want to generate some synthetic transactions to test our model, once it is trained.
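As a quick sketch of that idea, we could hand a trained model a couple of made-up transactions. The amounts here are arbitrary, and since `predictFraud` only looks at the amount, the second constructor argument is irrelevant for the prediction (assuming `model` is a trained `FraudDetectionModel` instance):

```java
// Hypothetical probe transactions with arbitrary amounts
Transaction smallPurchase = new Transaction(12.50, false);
Transaction largePurchase = new Transaction(9850.00, false);

System.out.println("Small purchase flagged as fraud? " + model.predictFraud(smallPurchase));
System.out.println("Large purchase flagged as fraud? " + model.predictFraud(largePurchase));
```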
Now, it is time to put all the pieces in action. We'll create `FraudDetectionApp` to hold all our components.

```java
package com.mongodb;

import java.io.IOException;
import java.util.Collections;
import java.util.List;

public class FraudDetectionApp {
    private static FraudDetectionModel fraudDetectionModel;

    public static void main(String[] args) {
        MongoDBConnector mongoDBConnector = new MongoDBConnector();
        TransactionRepository transactionRepository = new TransactionRepository(mongoDBConnector);
        DataPreprocessor preprocessor = new DataPreprocessor(transactionRepository);

        try {
            // Load and prepare training data, and add to MongoDB
            preprocessor.loadData("src/main/resources/creditcard.csv");
            List<Transaction> transactions = transactionRepository.getAllTransactions();

            // Shuffle and split data
            Collections.shuffle(transactions);
            int trainSize = (int) (transactions.size() * 0.8);
            List<Transaction> trainSet = transactions.subList(0, trainSize);
            List<Transaction> testSet = transactions.subList(trainSize, transactions.size());

            System.out.println("Train size: " + trainSet.size() + ", Test size: " + testSet.size());

            // Create and train fraud detection model
            fraudDetectionModel = new FraudDetectionModel();
            fraudDetectionModel.trainModel(trainSet);

            // Evaluate model
            fraudDetectionModel.evaluateModel(testSet);

        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
```
Here, we load our data into our MongoDB database. Next, we shuffle and split our data into training and testing data. We then train our model and test the model with the remaining data.
Now, let's build and run our application, and take a look at our results.
Let's dive into what our output means, and what we can do with this information.
```
========================Evaluation Metrics========================
 # of classes:    2
 Accuracy:        0.8441
 Precision:       0.0011
 Recall:          0.1053
 F1 Score:        0.0022
Precision, recall & F1: reported for positive class (class 1 - "1") only
```
- Accuracy: Proportion of all correct predictions.
```java
double calculateAccuracy(int TP, int TN, int FP, int FN) {
    int totalPredictions = TP + TN + FP + FN;
    double accuracy = (double) (TP + TN) / totalPredictions;
    return accuracy;
}
```
- Precision: Proportion of correct positive predictions.
```java
double calculatePrecision(int TP, int FP) {
    double precision = (double) TP / (TP + FP);
    return precision;
}
```
- Recall: Proportion of actual positives correctly predicted.
```java
double calculateRecall(int TP, int FN) {
    double recall = (double) TP / (TP + FN);
    return recall;
}
```
- F1 Score: Harmonic mean of Precision and Recall.
```java
double calculateF1Score(double precision, double recall) {
    double f1Score = 2 * (precision * recall) / (precision + recall);
    return f1Score;
}
```
As you can see, the accuracy is reasonably high, but why is the precision so low? It's because we can classify most results as not fraudulent and still be correct, even if it is a poorly informed classification. With fraud making up only 0.172% of the data, a model that labels every transaction as legitimate would still be about 99.8% accurate—while catching zero fraud.
The confusion matrix summarizes the performance of a classification model by showing the counts of actual versus predicted classifications. Here's how to interpret the confusion matrix provided:
```
=========================Confusion Matrix=========================
    0    1
-----------
 9615 1759 | 0 = 0
   17    2 | 1 = 1

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
```
- Rows represent the actual classes (ground truth):
  - The first row corresponds to actual class `0`.
  - The second row corresponds to actual class `1`.
- Columns represent the predicted classes:
  - The first column corresponds to predicted class `0`.
  - The second column corresponds to predicted class `1`.
- 9615: Number of times the model correctly predicted `0` when the actual class was `0` (True Negatives, TN)
- 1759: Number of times the model incorrectly predicted `1` when the actual class was `0` (False Positives, FP)
- 17: Number of times the model incorrectly predicted `0` when the actual class was `1` (False Negatives, FN)
- 2: Number of times the model correctly predicted `1` when the actual class was `1` (True Positives, TP)
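As a quick sanity check, plugging these counts into the formulas above reproduces the reported metrics:

```java
// TN = 9615, FP = 1759, FN = 17, TP = 2 (from the confusion matrix above)
double precision = 2.0 / (2 + 1759);                                // ≈ 0.0011
double recall    = 2.0 / (2 + 17);                                  // ≈ 0.1053
double f1        = 2 * (precision * recall) / (precision + recall); // ≈ 0.0022
```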
This confusion matrix suggests that the model is heavily biased towards class `0`, as most predictions are for class `0`. This is common in datasets with severe class imbalance.

When I reran the model using all the available features, I ended up with much better accuracy! As a tradeoff, it did not declare any transaction as fraud. Not very useful. So how can this be resolved?
```
========================Evaluation Metrics========================
 # of classes:    2
 Accuracy:        0.9986
 Precision:       0.0000
 Recall:          0.0000
 F1 Score:        0.0000
Precision, recall & F1: reported for positive class (class 1 - "1") only

Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]

=========================Confusion Matrix=========================
     0 1
-------------
 11377 0 | 0 = 0
    16 0 | 1 = 1

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
```
Well, by actually addressing the way we are handling our data. We need to understand the differences between our fraudulent transactions and our credible ones. There is a whole host of research on this topic, and you can learn more about some of the methods in articles like What is Imbalanced Data and How to Handle It.
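One concrete option mentioned earlier—a weighted loss function—is something Deeplearning4J supports directly, since the negative log likelihood loss accepts per-class weights. Here's a sketch of how our output layer could be adjusted; the weights themselves are hypothetical and would need tuning against your data:

```java
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;

// ...inside initializeModel(), replacing the original output layer:
.layer(1, new OutputLayer.Builder()
        // Hypothetical weights: penalize missed fraud (class 1) far more
        // heavily than misclassified legitimate transactions (class 0)
        .lossFunction(new LossNegativeLogLikelihood(Nd4j.create(new double[]{1.0, 100.0})))
        .nIn(10)
        .nOut(NUM_CLASSES)
        .activation(Activation.SOFTMAX)
        .build())
```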
In this tutorial, we’ve walked through the steps to create a fraud detection system using Deeplearning4J and MongoDB. AI is a hot topic, and with Deeplearning4J, you can integrate AI into your Java applications.
If you want to learn more about what you can do with MongoDB and Java in the AI world, check out How to Deploy Vector Search, Atlas Search, and Search Nodes With the Atlas Kubernetes Operator.