How To Build a Fraud Detection Model in Java Using Deeplearning4J

Tim Kelly • 15 min read • Published Jan 27, 2025 • Updated Jan 27, 2025
In this tutorial, we’ll learn how to combine the power of neural networks and scalable data management to build a real-world fraud detection system. AI isn’t just for JavaScript and Python anymore. Using Deeplearning4J, we’ll train a neural network in Java, and with MongoDB, we’ll manage and store transaction data efficiently.
Given a whole host of data on customer transactions, we are going to teach our model how to spot fraud. Whether it’s a suspicious amount, an odd location, or a peculiar time, a lot of variables are at play, and it’s important to get it right! We’ll preprocess our transaction data, train a neural network, and integrate MongoDB for scalable data storage. By the end, we’ll have a fully functional application capable of identifying fraudulent transactions.
If you want to clone this repo, or just view all the code, check out the GitHub repository.

What we’ll build

We’ll create a system that:
  1. Loads and preprocesses transaction data from a CSV file.
  2. Trains a neural network to classify transactions as fraudulent or non-fraudulent.
  3. Stores and retrieves transaction data in MongoDB, ensuring scalability and persistence.

Why Java for AI?

Java often flies under the radar in the AI world, but it has some undeniable strengths that make it a solid choice for building AI systems, especially when we need to go beyond experiments and into scalable production environments.
  • Integration with enterprise systems: Java is a cornerstone in enterprise software. When we want to embed AI into existing systems, Java’s prevalence makes it a natural fit.
  • Performance and scalability: With its multithreading and JVM optimizations, Java handles high-performance and distributed applications with elegance and grace (hyperbole)—perfect for real-world AI workloads.
By using Java for AI, we’re not just building something cool—we’re creating something robust, scalable, and ready to handle the demands of real production workloads. Enough yapping about why it’s cool; let’s actually start building something.

What we’ll cover

  1. Data preprocessing: How to clean and prepare transaction data for training
  2. Training the neural network: Using Deeplearning4J to train a classification model
  3. Integrating MongoDB: Managing transaction data in a scalable and efficient database
  4. Real-time interaction: Building an interactive CLI for testing our model and running predictions. This can just as easily be an API, or any other way you plan on interacting with your application.

Prerequisites

To follow along with this tutorial, you should have the following:
  • A recent JDK and Maven installed
  • A MongoDB Atlas cluster (a free M0 tier cluster is enough for this tutorial) and its connection string
  • The Credit Card Fraud dataset from Kaggle, downloaded as creditcard.csv

What is Deeplearning4J?

Deeplearning4J is a Java-based deep learning library designed for developers who want to build production-ready AI systems without leaving the JVM ecosystem. Unlike some of the more experimental tools in the AI world, Deeplearning4J is much more mature and focuses on real-world use cases, offering features that integrate with enterprise applications.
At the risk of sounding like an ad campaign for it, here’s what makes it stand out:
  • Native Java support: It’s built for Java developers, so we can use familiar tools and workflows (like MongoDB) to create deep learning models.
  • Scalability: Deeplearning4J supports distributed training right out of the box, making it ideal for large datasets and high-performance applications.
  • Flexibility: Whether we’re working on classification, regression, or more complex architectures, the library provides the building blocks to customize and optimize our models.
  • Integration: Deeplearning4J plays nicely with other Java tools, including Apache Spark for big data processing and Hadoop for distributed systems.
With Deeplearning4J, we’re doing more than just experimenting with AI in Java—we’re building models that can be deployed directly into production environments, alongside the rest of our Java applications.

Understanding the data

The dataset we’re working with is the Credit Card Fraud dataset available on Kaggle. It is a widely recognized credit card fraud detection dataset containing anonymized transaction data, and it provides a solid foundation for training our machine learning model to detect fraudulent transactions.

Exploring Our Credit Card Fraud dataset

This dataset has the following key characteristics:
  1. Size and composition:
    • Total transactions: 284,807
    • Fraudulent transactions: 492 (0.172% of the data)
    • The dataset is highly imbalanced, as fraudulent transactions represent a very small percentage of the total. This reflects the real-world challenge where fraud is rare but critical to detect.
  2. Features:
    • Time: The seconds elapsed between a transaction and the first transaction in the dataset.
    • Amount: The monetary value of the transaction.
    • V1 to V28: Principal components resulting from PCA (Principal Component Analysis), used to anonymize sensitive data.
    • Class: The target variable, where 0 indicates a legitimate transaction and 1 indicates fraud.

Acknowledging the class imbalance

The extreme class imbalance (fraudulent transactions make up only 0.172% of the dataset) can complicate machine learning:
  • Models can become biased toward the majority class (0), leading to poor performance on the minority class (1).
For this tutorial, we won’t focus on addressing the imbalance. Instead, we’ll train and evaluate the model on the original dataset, treating it as a learning exercise. However, techniques like resampling, weighted loss functions, or cost-sensitive learning would be essential to deal with an imbalanced dataset.

Preparing the data

To get the data ready for our model, we’ll follow these steps:
  1. Selecting features:
    • We’ll start with the Amount feature as our input.
    • The Time feature and PCA components (V1 to V28) will be excluded for simplicity, but they can be explored later to enhance the model. I ran a test with them later, and you'll be able to see why I decided to exclude them from our simple implementation.
  2. Target variable:
    • The Class column serves as our target variable, with 0 for legitimate transactions and 1 for fraud.
  3. Normalizing data:
    • Neural networks perform better with consistent input scaling, so we’ll normalize the Amount column to a range of 0 to 1 (see the sketch after this list).
  4. Splitting the data:
    • We’ll divide the dataset into a training set (80%) and a test set (20%) to evaluate how well our model generalizes to unseen data.
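Step 3 could look something like the minimal sketch below. The class and method names are my own, and this helper is not part of the tutorial's code; the preprocessor we write later stores the raw amounts for simplicity.
```java
import java.util.Arrays;

// Illustrative only: min-max scaling of transaction amounts to the range [0, 1].
// This is a sketch and not part of the tutorial's DataPreprocessor.
public class AmountNormalizer {
    public static double[] minMaxNormalize(double[] amounts) {
        double min = Arrays.stream(amounts).min().orElse(0.0);
        double max = Arrays.stream(amounts).max().orElse(0.0);
        double range = (max - min) == 0 ? 1.0 : (max - min); // guard against division by zero
        return Arrays.stream(amounts)
                .map(a -> (a - min) / range)
                .toArray();
    }
}
```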
By keeping the preprocessing straightforward, we can focus on understanding the mechanics of building and training a neural network. While we’re not addressing the class imbalance directly in this tutorial, it remains an absolutely critical consideration for real-world applications. I will reiterate this point a couple of times throughout this tutorial, but that’s because it is probably the most influential factor in the results of this model. Let’s move on to the actual preprocessing code!

Our dependencies

First, we need to add our dependencies to our POM.
```xml
<dependencies>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    <dependency>
        <groupId>org.mongodb</groupId>
        <artifactId>mongodb-driver-sync</artifactId>
        <version>5.2.0</version>
    </dependency>
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-simple</artifactId>
        <version>2.0.16</version>
    </dependency>
</dependencies>
```
Here, we are importing a few dependencies for Deeplearning4J, as well as the MongoDB Java driver. We also have SLF4J just for logging, and ND4J, which provides convenience methods for creating n-dimensional arrays from Java float and double arrays.

Setting up our MongoDB connection

Create a MongoDBConnector class. This will be how we establish our connection to our MongoDB database, where we will store our data.
For the URI, add your connection string for your database. Change your database and collection names to whatever makes sense, but I'm using fraudDetection and transactions, respectively, for this example.
```java
package com.mongodb;

import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import org.bson.Document;

import java.util.concurrent.TimeUnit;

public class MongoDBConnector {
    private static final String URI = "YOUR-CONNECTION-STRING";
    private static final String DATABASE_NAME = "fraudDetection";
    private static final String COLLECTION_NAME = "transactions";

    private final MongoClient mongoClient;

    public MongoDBConnector() {
        MongoClientSettings settings = MongoClientSettings.builder()
                .applyConnectionString(new ConnectionString(URI))
                .applyToSocketSettings(builder ->
                        builder.connectTimeout(30, TimeUnit.SECONDS)
                               .readTimeout(30, TimeUnit.SECONDS))
                .build();

        mongoClient = MongoClients.create(settings);
    }

    public MongoCollection<Document> getCollection() {
        return mongoClient.getDatabase(DATABASE_NAME).getCollection(COLLECTION_NAME);
    }
}
```
We are also applying specific timeout settings to our MongoClient. Since we'll be uploading a large amount of data to our database in large chunks, these settings help us avoid timeout exceptions caused by transient network issues.

Creating the transaction POJO

To work with our transaction data, we will create a Transaction class. We are going to simplify it for this demo, and just use the amount and whether it has been marked as fraudulent or not.
This will significantly affect the reliability of our model, so feel free to add more features to classify on, as you see fit. AI is a complex world, and training models is a real skill. Different features will be more reliable for predicting fraud. Think of large transactions on the far side of the globe, or during odd hours for that particular user.
```java
package com.mongodb;

import org.bson.Document;

public class Transaction {
    private double amount;
    private boolean isFraudulent;

    public Transaction(double amount, boolean isFraudulent) {
        this.amount = amount;
        this.isFraudulent = isFraudulent;
    }

    public double getAmount() { return amount; }
    public boolean isFraudulent() { return isFraudulent; }

    public Document toDocument() {
        Document doc = new Document();
        doc.append("amount", amount);
        doc.append("isFraudulent", isFraudulent);
        return doc;
    }
}
```
This will provide a simple structure for our data and allow us to convert it into a MongoDB-friendly format with the toDocument method.
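For example, inserting a single transaction by hand (assuming a collection obtained from the MongoDBConnector above) could look something like this:
```java
// Hypothetical one-off insert; `collection` is assumed to come from MongoDBConnector#getCollection()
MongoCollection<Document> collection = new MongoDBConnector().getCollection();
collection.insertOne(new Transaction(149.62, false).toDocument());
```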

Storing and saving our data

Next, we need a TransactionRepository class. This will be where we encapsulate the operations for saving and fetching our transactions from the database. This abstraction lets us keep our data access logic organized and reusable.
```java
package com.mongodb;

import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.BulkWriteOptions;
import com.mongodb.client.model.WriteModel;
import org.bson.Document;

import java.util.ArrayList;
import java.util.List;

public class TransactionRepository {
    private final MongoCollection<Document> collection;

    public TransactionRepository(MongoDBConnector connector) {
        this.collection = connector.getCollection();
    }

    public void bulkSaveTransactions(List<WriteModel<Document>> transactions) {
        if (!transactions.isEmpty()) {
            try {
                BulkWriteOptions options = new BulkWriteOptions().ordered(true);
                collection.bulkWrite(transactions, options); // Perform bulk write operation
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    public List<Transaction> getAllTransactions() {
        List<Transaction> transactions = new ArrayList<>();
        for (Document doc : collection.find()) {
            double amount = doc.getDouble("amount");
            boolean isFraudulent = doc.getBoolean("isFraudulent");
            transactions.add(new Transaction(amount, isFraudulent));
        }
        return transactions;
    }
}
```
This repository pattern makes it easy to add more database operations, such as filtering our transactions or performing more advanced queries.
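For instance, a method that fetches only the transactions flagged as fraudulent could be added to the repository along these lines (a sketch, not part of the class above; it needs an extra import of com.mongodb.client.model.Filters):
```java
// Sketch of an additional query method using the driver's Filters helper.
public List<Transaction> getFraudulentTransactions() {
    List<Transaction> frauds = new ArrayList<>();
    for (Document doc : collection.find(Filters.eq("isFraudulent", true))) {
        frauds.add(new Transaction(doc.getDouble("amount"), doc.getBoolean("isFraudulent")));
    }
    return frauds;
}
```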

Preprocessing data in Java

Effective preprocessing of the data is an essential step in machine learning. Here, we'll load the data from the creditcard.csv file, which should be placed in the src/main/resources directory, read in the fields we want to focus on for model training, and format them for our Transaction model.
```java
package com.mongodb;

import com.mongodb.client.model.WriteModel;
import com.mongodb.client.model.InsertOneModel;
import org.bson.Document;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class DataPreprocessor {

    TransactionRepository transactionRepository;

    private static final int BATCH_SIZE = 500; // Define batch size for bulk writes
    private static final int DOCUMENT_LIMIT = 250000; // Maximum documents to insert

    private int documentCount = 0; // Counter for total inserted documents

    public DataPreprocessor(TransactionRepository transactionRepository) {
        this.transactionRepository = transactionRepository;
    }

    public void loadData(String filePath) throws IOException {
        try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
            // Skip the header by reading the first line
            reader.readLine();

            List<String> batch = new ArrayList<>();
            String line;

            // Read the file line-by-line
            while ((line = reader.readLine()) != null) {
                if (documentCount >= DOCUMENT_LIMIT) {
                    System.out.println("Reached the document limit of " + DOCUMENT_LIMIT + ". Stopping data load.");
                    break; // Stop processing when the limit is reached
                }

                batch.add(line);
                documentCount++;

                // When batch size is reached, process it
                if (batch.size() == BATCH_SIZE) {
                    processBatch(batch);
                    batch.clear(); // Clear the batch for the next set of lines
                }
            }

            // Process any remaining lines
            if (!batch.isEmpty()) {
                processBatch(batch);
            }
        }
    }

    private void processBatch(List<String> batch) {
        List<WriteModel<Document>> bulkOperations = batch.stream()
                .map(line -> {
                    String[] fields = line.split(",");
                    double amount = Double.parseDouble(fields[29]); // Adjust index as needed
                    boolean isFraudulent = "1".equals(fields[30]);
                    Transaction transaction = new Transaction(amount, isFraudulent);
                    return new InsertOneModel<>(transaction.toDocument());
                })
                .collect(Collectors.toList());
        transactionRepository.bulkSaveTransactions(bulkOperations);
    }
}
```
To handle large datasets efficiently, we are using MongoDB’s bulkWrite operation. This reduces the number of network round trips between our application and our MongoDB instance, which improves the performance of our application.
We are also limiting the number of documents added to the database to 250,000. This is because we are using a MongoDB M0 tier cluster, which is limited to 512 MB of storage, and we do not want to exceed this.

Building a neural network with Deeplearning4J

To detect fraudulent transactions, we'll use Deeplearning4J to build our neural network. We'll do this in a FraudDetectionModel class.

Setting up the neural network

We will set up a simple feedforward network to classify whether a transaction is fraudulent or not.
  • Input layer (implicit):
    Every neural network begins with an input layer that receives the data. In Deeplearning4J, this is handled automatically by the first layer. The nIn(NUM_INPUT_FEATURES) in the first DenseLayer specifies how many input features the model will take. For now, we’re using only the Amount feature, but you can easily expand this.
  • Dense layer (hidden layer):
    This is the core of our model where learning happens. It consists of 10 neurons and uses the ReLU (Rectified Linear Unit) activation function, which introduces non-linearity to help the model learn complex patterns in the data.
  • Output layer:
    The output layer predicts whether a transaction is fraudulent or not. It has two neurons—one for each class (fraudulent or legitimate)—and uses the Softmax activation function to generate probabilities for each class. The class with the highest probability becomes the model’s prediction.
  • Loss function and optimizer:
    • The model uses the Negative Log Likelihood loss function, ideal for classification problems.
    • It’s optimized using the Adam optimizer with a learning rate of 0.001, which adapts the learning rate throughout training for better convergence.
    • Xavier weight initialization is used to keep the starting weights balanced, preventing issues like vanishing or exploding gradients.
  • Training feedback:
    We added a ScoreIterationListener that outputs the model’s training progress every 10 iterations, giving us insight into how well the model is learning.
If none of what I just said makes any sense, don't worry—it didn't for me when I started either. I've shown you how to implement it below, so you can see it in action. AI is a large field, and a lot of brilliant people have done amazing work to make it more accessible. Check out MongoDB's Developer Center to learn more about AI.
```java
package com.mongodb;

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.Collections;
import java.util.List;

public class FraudDetectionModel {
    private MultiLayerNetwork model;

    private static final int NUM_INPUT_FEATURES = 1;
    private static final int NUM_CLASSES = 2;

    public FraudDetectionModel() {
        initializeModel();
    }

    private void initializeModel() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .weightInit(WeightInit.XAVIER)
                .updater(new Adam(0.001))
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(NUM_INPUT_FEATURES)
                        .nOut(10)
                        .activation(Activation.RELU)
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(10)
                        .nOut(NUM_CLASSES)
                        .activation(Activation.SOFTMAX)
                        .build())
                .build();

        model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10));
    }
}
```
So what does all this mean? Well, let’s explore the components of this class and what each part does.

Class variables

```java
private MultiLayerNetwork model;
private static final int NUM_INPUT_FEATURES = 1;
private static final int NUM_CLASSES = 2;
```
  • model: The neural network instance that will be trained
  • NUM_INPUT_FEATURES: The number of input features—currently set to 1 because only the transaction amount is used
  • NUM_CLASSES: Set to 2, representing two possible outcomes: legitimate (0) or fraudulent (1)

Model initialization

```java
private void initializeModel() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam(0.001))
            .list()
```
  • NeuralNetConfiguration.Builder(): Used to define the network architecture.
  • seed(123): Sets a random seed for reproducibility.
  • weightInit(WeightInit.XAVIER): Initializes the network weights using Xavier initialization, which balances the variance of weights, helping the network train efficiently.
  • updater(new Adam(0.001)): Uses the Adam optimizer with a learning rate of 0.001 for adaptive learning

Hidden layer (dense layer)

```java
.layer(0, new DenseLayer.Builder()
        .nIn(NUM_INPUT_FEATURES)
        .nOut(10)
        .activation(Activation.RELU)
        .build())
```
  • Layer 0: The first hidden layer
  • nIn(NUM_INPUT_FEATURES): Number of inputs, currently 1 (transaction amount)
  • nOut(10): Number of neurons in this layer (10 neurons)
  • Activation.RELU: ReLU activation introduces non-linearity, allowing the model to learn complex relationships

Output layer

```java
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .nIn(10)
        .nOut(NUM_CLASSES)
        .activation(Activation.SOFTMAX)
        .build())
```
  • Layer 1: The output layer
  • LossFunction.NEGATIVELOGLIKELIHOOD: A loss function suitable for classification tasks
  • nIn(10): Takes the 10 outputs from the hidden layer
  • nOut(NUM_CLASSES): Outputs two values representing the probabilities of each class (fraudulent or legitimate)
  • Activation.SOFTMAX: Converts outputs into class probabilities that sum up to 1.

Model initialization and training feedback

```java
model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(10));
```
  • model = new MultiLayerNetwork(conf);: Creates the neural network using the configuration
  • model.init();: Initializes the network’s parameters
  • setListeners(new ScoreIterationListener(10));: Logs the model's score (error) every 10 iterations to track training progress
Next, we'll add the method prepareTrainingData to our class. The network expects data to be in the form of features (inputs) and labels (targets).
```java
private DataSet prepareTrainingData(List<Transaction> transactions) {
    int numTransactions = transactions.size();
    INDArray features = Nd4j.create(numTransactions, NUM_INPUT_FEATURES);
    INDArray labels = Nd4j.create(numTransactions, NUM_CLASSES);

    for (int i = 0; i < numTransactions; i++) {
        Transaction transaction = transactions.get(i);

        // Use transaction amount as feature
        features.putScalar(new int[]{i, 0}, transaction.getAmount());

        // One-hot encoding for labels
        if (transaction.isFraudulent()) {
            labels.putScalar(new int[]{i, 1}, 1.0);
            labels.putScalar(new int[]{i, 0}, 0.0);
        } else {
            labels.putScalar(new int[]{i, 0}, 1.0);
            labels.putScalar(new int[]{i, 1}, 0.0);
        }
    }

    return new DataSet(features, labels);
}
```
We train the model with a simple loop over multiple epochs. An epoch is a machine learning term for one complete pass of the entire dataset through the learning algorithm.
```java
public void trainModel(List<Transaction> transactions) {
    // Shuffle the data to ensure random distribution
    Collections.shuffle(transactions);

    // Prepare training data
    DataSet dataSet = prepareTrainingData(transactions);

    // Train the model
    for (int epoch = 0; epoch < 100; epoch++) {
        model.fit(dataSet);
    }

    System.out.println("Model trained successfully.");
}
```
After training, we'll add an evaluateModel method to measure the model's performance. This step is crucial in AI model training, as it lets us refine how we implement the neural network.
```java
public void evaluateModel(List<Transaction> transactions) {
    // Split data into train and test sets
    int trainSize = (int) (transactions.size() * 0.8);
    List<Transaction> testSet = transactions.subList(trainSize, transactions.size());

    // Prepare test data
    DataSet testData = prepareTrainingData(testSet);

    // Perform evaluation
    Evaluation evaluation = new Evaluation(NUM_CLASSES);
    INDArray predicted = model.output(testData.getFeatures());
    evaluation.eval(testData.getLabels(), predicted);

    // Print evaluation statistics
    System.out.println(evaluation.stats());
}
```
We'll see this in our main application later, and we'll learn how to interpret our results.
Lastly, we'll add a method to predict fraud for a singular transaction.
```java
public boolean predictFraud(Transaction transaction) {
    // Convert transaction to INDArray
    INDArray input = Nd4j.create(new double[][]{{transaction.getAmount()}});

    // Perform prediction
    INDArray output = model.output(input);

    // Interpret the output
    // Index 1 corresponds to the fraud class (assuming one-hot encoding)
    return output.getDouble(0, 1) > 0.5;
}
```
We can use this if we want to generate some synthetic transactions to test our model, once it is trained.
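For example, a quick sanity check on a couple of made-up transactions might look something like this (the amounts are arbitrary, and `model` is assumed to be a trained FraudDetectionModel instance):
```java
// Hypothetical synthetic transactions for a quick post-training sanity check.
// predictFraud only looks at the amount; the label passed to the constructor is ignored here.
Transaction small = new Transaction(12.50, false);
Transaction large = new Transaction(9800.00, false);

System.out.println("Small amount flagged as fraud? " + model.predictFraud(small));
System.out.println("Large amount flagged as fraud? " + model.predictFraud(large));
```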

Fraud detection model in action

Now, it is time to put all the pieces into action. We'll create a FraudDetectionApp class to tie all our components together.
```java
package com.mongodb;

import java.io.IOException;
import java.util.Collections;
import java.util.List;

public class FraudDetectionApp {
    private static FraudDetectionModel fraudDetectionModel;

    public static void main(String[] args) {
        MongoDBConnector mongoDBConnector = new MongoDBConnector();
        TransactionRepository transactionRepository = new TransactionRepository(mongoDBConnector);
        DataPreprocessor preprocessor = new DataPreprocessor(transactionRepository);

        try {
            // Load and prepare training data, and add to MongoDB
            preprocessor.loadData("src/main/resources/creditcard.csv");
            List<Transaction> transactions = transactionRepository.getAllTransactions();

            // Shuffle and split data
            Collections.shuffle(transactions);
            int trainSize = (int) (transactions.size() * 0.8);
            List<Transaction> trainSet = transactions.subList(0, trainSize);
            List<Transaction> testSet = transactions.subList(trainSize, transactions.size());

            System.out.println("Train size: " + trainSet.size() + ", Test size: " + testSet.size());

            // Create and train fraud detection model
            fraudDetectionModel = new FraudDetectionModel();
            fraudDetectionModel.trainModel(trainSet);

            // Evaluate model
            fraudDetectionModel.evaluateModel(testSet);

        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
```
Here, we load our data into our MongoDB database. Next, we shuffle and split our data into training and testing data. We then train our model and test the model with the remaining data.
Now, let's build and run our application, and take a look at our results.
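If you're working from the command line with Maven, one way to do that (assuming the exec-maven-plugin is available in your setup) is:
```
mvn compile exec:java -Dexec.mainClass="com.mongodb.FraudDetectionApp"
```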

Understanding our evaluation metrics

Let's dive into what our output means, and what we can do with this information.
```
========================Evaluation Metrics========================
 # of classes:    2
 Accuracy:        0.8441
 Precision:       0.0011
 Recall:          0.1053
 F1 Score:        0.0022
Precision, recall & F1: reported for positive class (class 1 - "1") only
```
  • Accuracy: Proportion of all correct predictions.
```java
double calculateAccuracy(int TP, int TN, int FP, int FN) {
    int totalPredictions = TP + TN + FP + FN;
    double accuracy = (double) (TP + TN) / totalPredictions;
    return accuracy;
}
```
  • Precision: Proportion of correct positive predictions.
```java
double calculatePrecision(int TP, int FP) {
    double precision = (double) TP / (TP + FP);
    return precision;
}
```
  • Recall: Proportion of actual positives correctly predicted.
```java
double calculateRecall(int TP, int FN) {
    double recall = (double) TP / (TP + FN);
    return recall;
}
```
  • F1 Score: Harmonic mean of Precision and Recall.
```java
double calculateF1Score(double precision, double recall) {
    double f1Score = 2 * (precision * recall) / (precision + recall);
    return f1Score;
}
```
As you can see, the accuracy is reasonably high, but why is the precision so low? It's because the model can classify most transactions as not fraudulent and still be correct most of the time, even if that classification is poorly informed.
The confusion matrix summarizes the performance of a classification model by showing the counts of actual versus predicted classifications. Here's how to interpret the confusion matrix provided:
```
=========================Confusion Matrix=========================
    0    1
-----------
 9615 1759 | 0 = 0
   17    2 | 1 = 1

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
```
  • Rows represent the actual classes (ground truth):
    • The first row corresponds to actual class 0.
    • The second row corresponds to actual class 1.
  • Columns represent the predicted classes:
    • The first column corresponds to predicted class 0.
    • The second column corresponds to predicted class 1.

Values

  • 9615: Number of times the model correctly predicted 0 when the actual class was 0 (True Negatives, TN)
  • 1759: Number of times the model incorrectly predicted 1 when the actual class was 0 (False Positives, FP)
  • 17: Number of times the model incorrectly predicted 0 when the actual class was 1 (False Negatives, FN)
  • 2: Number of times the model correctly predicted 1 when the actual class was 1 (True Positives, TP)
This confusion matrix suggests that the model is heavily biased towards class 0, as most predictions are for class 0. This is common in datasets with severe class imbalance.
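To connect the matrix back to the metrics above, plugging these counts into the helper formulas reproduces the reported numbers:
```java
int TP = 2, TN = 9615, FP = 1759, FN = 17;

double accuracy  = (double) (TP + TN) / (TP + TN + FP + FN);        // 9617 / 11393 ≈ 0.8441
double precision = (double) TP / (TP + FP);                         // 2 / 1761     ≈ 0.0011
double recall    = (double) TP / (TP + FN);                         // 2 / 19       ≈ 0.1053
double f1        = 2 * (precision * recall) / (precision + recall); // ≈ 0.0022
```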
When I reorganized the model to use all of the available features, I ended up with much better accuracy! As a tradeoff, it did not flag a single transaction as fraud. Not very useful. So how can this be resolved?
```
========================Evaluation Metrics========================
 # of classes:    2
 Accuracy:        0.9986
 Precision:       0.0000
 Recall:          0.0000
 F1 Score:        0.0000
Precision, recall & F1: reported for positive class (class 1 - "1") only

Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]

=========================Confusion Matrix=========================
     0    1
-------------
 11377    0 | 0 = 0
    16    0 | 1 = 1

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
```
The answer is to address the way we are handling our data. We need to understand the differences between our fraudulent transactions and our credible ones, and make sure the model actually sees enough of the former. There is a whole host of research on this topic, and you can learn more about some of the methods in articles like What is Imbalanced Data and How to Handle It.
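As one concrete (if blunt) example, we could randomly undersample the legitimate transactions before training so the classes are roughly balanced. The sketch below is illustrative and not part of the tutorial's code; weighted loss functions and oversampling the minority class are other common options.
```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Illustrative random undersampling: keep every fraudulent transaction and an
// equal number of randomly chosen legitimate ones, so the training data is balanced.
public class Undersampler {
    public static List<Transaction> undersample(List<Transaction> all) {
        List<Transaction> frauds = new ArrayList<>();
        List<Transaction> legits = new ArrayList<>();
        for (Transaction t : all) {
            if (t.isFraudulent()) {
                frauds.add(t);
            } else {
                legits.add(t);
            }
        }
        Collections.shuffle(legits);
        List<Transaction> balanced = new ArrayList<>(frauds);
        balanced.addAll(legits.subList(0, Math.min(frauds.size(), legits.size())));
        Collections.shuffle(balanced);
        return balanced;
    }
}
```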

Conclusion

In this tutorial, we’ve walked through the steps to create a fraud detection system using Deeplearning4J and MongoDB. AI is a hot topic, and with Deeplearning4J, you can integrate it directly into your Java applications.
If you want to learn more about what you can do with MongoDB and Java in the AI world, check out How to Deploy Vector Search, Atlas Search, and Search Nodes With the Atlas Kubernetes Operator.
