Build a Self-Learning Web Application Firewall with Quarkus and DeepLearning4j
Explore how to create a self-learning Web Application Firewall without static rules, just smart, reactive Java security.
What if your Web Application Firewall (WAF) didn’t just block known threats, but actually learned what “normal” looked like, and evolved to stop new attacks on its own?
That idea has fascinated me for years. I’ve spent much of my career building and securing web applications, and along the way, I’ve developed a deep appreciation for application-layer security. WAF appliances have always intrigued me: the price tags, the rule engines, the sheer range of capabilities packed into a black box. But they’ve also left me wondering: Could I build one myself? Could a modern Java stack give me enough flexibility, performance, and control to roll my own intelligent WAF?
This tutorial is the result of that curiosity.
What you’ll build here is a smart, self-learning WAF powered by Quarkus and DeepLearning4j. It inspects every incoming request, extracts behavioral features, and runs them through a deep learning model to detect anomalies in real time. There is no vendor rule set to maintain and no lock-in: a handful of simple pattern checks catch the obvious attacks, and the model handles the rest. Just pure Java, fast inference, and a dash of machine learning.
A word of caution before we dive in: security teams have a saying, "never roll your own crypto," because homebrewed encryption tends to end in tears, breaches, or both. The same wisdom often applies to WAFs and other security-related components. While this tutorial shows what's possible, don’t mistake it for a certified, production-hardened defense layer. Think of it like building your own car’s brakes for fun: great for learning, bad for highways.
It’s a powerful demonstration of what’s possible when you combine Quarkus’ reactive runtime with real-time anomaly detection. And it might just inspire you to rethink how we approach web application defense in the age of AI.
Let’s build it from scratch.
What We’re Building
You’ll implement a complete security filter that:
Intercepts every HTTP request via a JAX-RS filter
Extracts behavioral features (like URL entropy, HTTP method, suspicious patterns)
Uses a DeepLearning4j autoencoder to detect anomalous requests
Decides whether to ALLOW, BLOCK, RATE_LIMIT, or LOG_ONLY
Retrains itself periodically to adapt to real traffic
Offers a management API for insights and tuning
Prerequisites
JDK 17+
Apache Maven 3.8+
Your favorite IDE (IntelliJ IDEA, VS Code)
cURL or HTTPie for testing
Optional: jq to pretty-print JSON output
Project Setup
Start by generating a fresh Quarkus project:
mvn io.quarkus.platform:quarkus-maven-plugin:3.24.4:create \
-DprojectGroupId=org.example \
-DprojectArtifactId=self-learning-waf \
-DclassName="org.example.GreetingResource" \
-Dpath="/hello" \
-Dextensions="rest-jackson,scheduler"
cd self-learning-waf
Note on the version in the mvn command: Most of my tutorials use the latest available Quarkus version implicitly by omitting the plugin version. But specifying it explicitly, like 3.24.4 in the command above, is actually good practice. It ensures consistent builds across different environments and avoids unexpected behavior if the plugin or platform introduces breaking changes in future releases. So if you see me skipping the version elsewhere, pin it yourself anyway.
Then open pom.xml and add the DeepLearning4j dependencies:
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>1.0.0-M2.1</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>1.0.0-M2.1</version>
</dependency>
And as usual, you can find the complete project on my GitHub repository, ready to clone and play around with.
Architecture Overview
Let’s break it down:
Feature Extraction: Convert HTTP request metadata into a numerical vector (length, header size, entropy, etc.)
Autoencoder: Train a DL4J model on normal traffic. It learns to reconstruct “normal” requests with low error.
Detection: If a request’s reconstruction error is high, it’s flagged as anomalous.
Action: Based on thresholds and severity, we allow, block, or rate-limit.
This blend of rule-based checks and machine learning gives us robust coverage against both known and unknown threats.
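The detection and action steps reduce to a few lines of arithmetic. The following is a minimal sketch in plain Java with no DL4J dependency. The class name `DetectionSketch`, the hand-written "reconstructed" vectors, and the sample values are illustrative assumptions; the 0.15 threshold and the 1.5 confidence cut-off mirror the values used in the filter in this tutorial.

```java
public class DetectionSketch {

    enum Action { ALLOW, LOG_ONLY, BLOCK }

    // Mean squared error between a feature vector and its reconstruction
    static double meanSquaredError(double[] input, double[] reconstructed) {
        double sum = 0.0;
        for (int i = 0; i < input.length; i++) {
            double diff = input[i] - reconstructed[i];
            sum += diff * diff;
        }
        return sum / input.length;
    }

    // Map a reconstruction error to an action using a threshold
    static Action decide(double error, double threshold) {
        if (error <= threshold) {
            return Action.ALLOW;
        }
        // Confidence is the error relative to the threshold, capped at 2.0
        double confidence = Math.min(error / threshold, 2.0);
        return confidence > 1.5 ? Action.BLOCK : Action.LOG_ONLY;
    }

    public static void main(String[] args) {
        double[] input = {1.0, 0.0, 1.0, 0.0};
        // Stand-ins for what a trained autoencoder might return (made-up numbers)
        double[] closeReconstruction = {0.9, 0.1, 0.9, 0.1};
        double[] poorReconstruction = {0.2, 0.7, 0.3, 0.6};
        System.out.println(decide(meanSquaredError(input, closeReconstruction), 0.15)); // prints ALLOW
        System.out.println(decide(meanSquaredError(input, poorReconstruction), 0.15));  // prints BLOCK
    }
}
```

Errors between one and one and a half times the threshold land in the middle band: they are logged but not blocked.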
Define Data Models
Create src/main/java/org/example/waf/WAFModels.java:
package org.example.waf;
import java.util.Map;
enum WAFAction {
ALLOW, BLOCK, RATE_LIMIT, LOG_ONLY
}
record RestRequestInfo(
String method,
String path,
String query,
String clientIP,
long timestamp,
Map<String, String> headers,
String contentType,
int contentLength) {
}
record WAFDecision(
boolean isAnomaly,
WAFAction action,
double confidence,
double reconstructionError,
String reason) {
}
record WAFLogEntry(
long timestamp,
String clientIP,
String method,
String path,
WAFAction action,
double reconstructionError,
long processingTime,
String reason) {
}
We’re using Java records for clarity and immutability.
Build the WAF Filter
Now we implement the heart of our system: a JAX-RS filter with integrated DL4J inference. Create src/main/java/org/example/waf/SelfLearningWAFFilter.java.
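Start with the package declaration, imports, annotations, and fields:
package org.example.waf;
import java.util.HashMap;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.AutoEncoder;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import io.vertx.core.http.HttpServerRequest;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.container.ContainerRequestFilter;
import jakarta.ws.rs.container.PreMatching;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.ext.Provider;
@Provider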
@ApplicationScoped
@PreMatching
public class SelfLearningWAFFilter implements ContainerRequestFilter {
private static final Logger logger = Logger.getLogger(SelfLearningWAFFilter.class.getName());
// Inject Vert.x request to get client IP robustly
@Context
HttpServerRequest vertxRequest;
// Model and configuration
private MultiLayerNetwork autoEncoder;
private final int inputSize = 25; // Size of our feature vector
private volatile double anomalyThreshold = 0.15;
// Thread-safe collections for tracking and training
private final Map<String, AtomicInteger> requestCounts = new ConcurrentHashMap<>();
private final Map<String, AtomicLong> lastRequestTime = new ConcurrentHashMap<>();
private final Queue<double[]> trainingBuffer = new ConcurrentLinkedQueue<>();
private final Queue<WAFLogEntry> auditLog = new ConcurrentLinkedQueue<>();
// WAF parameters
private final int bufferSize = 2000; // Max requests to hold for training
private final int maxAuditLogSize = 10000;
private final long rateLimitWindow = 60000; // 1 minute
private final int maxRequestsPerWindow = 100;
// Scheduler for background tasks
private ScheduledExecutorService scheduler;
private final AtomicInteger trainingCycle = new AtomicInteger(0);
Now initialize and clean up resources:
@PostConstruct
public void init() {
logger.info("Initializing Self-Learning WAF Filter...");
initializeModel();
scheduler = Executors.newScheduledThreadPool(2);
// Schedule periodic training to run every 5 minutes
scheduler.scheduleAtFixedRate(this::performIncrementalTraining, 5, 5, TimeUnit.MINUTES);
// Schedule data cleanup to run every hour
scheduler.scheduleAtFixedRate(this::cleanupOldData, 1, 1, TimeUnit.HOURS);
logger.info("Self-Learning WAF Filter initialized successfully.");
}
@PreDestroy
public void cleanup() {
logger.info("Shutting down WAF scheduler.");
if (scheduler != null) {
scheduler.shutdownNow();
}
}
Now add the filter() method to handle incoming traffic:
@Override
public void filter(ContainerRequestContext requestContext) {
long startTime = System.currentTimeMillis();
try {
// 1. Extract request info
RestRequestInfo requestInfo = extractRequestInfo(requestContext);
// 2. Analyze the request
WAFDecision decision = analyzeRequest(requestInfo);
// 3. Log the outcome
logDecision(requestInfo, decision, System.currentTimeMillis() - startTime);
// 4. Act on the decision
switch (decision.action()) {
case BLOCK -> {
logger.warning("BLOCK: " + decision.reason() + " from IP " + requestInfo.clientIP());
abortWithSecurityResponse(requestContext, Response.Status.FORBIDDEN, "WAF_BLOCK",
"Request blocked by security policy");
}
case RATE_LIMIT -> {
logger.info("RATE_LIMIT: " + requestInfo.clientIP());
abortWithSecurityResponse(requestContext, Response.Status.TOO_MANY_REQUESTS, "RATE_LIMIT",
"Rate limit exceeded");
}
default -> {
// ALLOW or LOG_ONLY, let request proceed
}
}
} catch (Exception e) {
logger.log(Level.SEVERE, "Error in WAF filter, allowing request to proceed.", e);
}
}
Now add the core logic. This includes initializing the DL4J model, extracting features from the request, and analyzing those features to make a decision.
private void initializeModel() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new Adam(0.001))
.list()
// Encoder layers
.layer(new AutoEncoder.Builder().nIn(inputSize).nOut(20).activation(Activation.RELU).build())
.layer(new AutoEncoder.Builder().nIn(20).nOut(15).activation(Activation.RELU).build())
.layer(new AutoEncoder.Builder().nIn(15).nOut(8).activation(Activation.RELU).build()) // Bottleneck
// Decoder layers
.layer(new AutoEncoder.Builder().nIn(8).nOut(15).activation(Activation.RELU).build())
.layer(new AutoEncoder.Builder().nIn(15).nOut(20).activation(Activation.RELU).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.nIn(20).nOut(inputSize).activation(Activation.SIGMOID).build())
.build();
autoEncoder = new MultiLayerNetwork(conf);
autoEncoder.init();
}
private RestRequestInfo extractRequestInfo(ContainerRequestContext context) {
Map<String, String> headers = new HashMap<>();
context.getHeaders().forEach((key, values) -> {
if (!values.isEmpty())
headers.put(key, values.get(0));
});
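// Parse Content-Length defensively: an unparseable value would throw here, and because
// filter() fails open on exceptions, the request would skip analysis entirely
int contentLength = 0;
String contentLengthHeader = context.getHeaderString("Content-Length");
if (contentLengthHeader != null) {
try {
contentLength = (int) Math.max(0, Math.min(Long.parseLong(contentLengthHeader.trim()), Integer.MAX_VALUE));
} catch (NumberFormatException e) {
contentLength = Integer.MAX_VALUE; // Treat garbage as maximally large
}
}
return new RestRequestInfo(
context.getMethod(),
context.getUriInfo().getPath(),
context.getUriInfo().getRequestUri().getQuery(),
getClientIP(context.getHeaderString("X-Forwarded-For")),
System.currentTimeMillis(),
headers,
context.getMediaType() != null ? context.getMediaType().toString() : "",
contentLength);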
}
private String getClientIP(String xForwardedFor) {
if (xForwardedFor != null && !xForwardedFor.isBlank()) {
return xForwardedFor.split(",")[0].trim();
}
// Fallback to Vert.x request's remote address
return vertxRequest.remoteAddress().host();
}
private WAFDecision analyzeRequest(RestRequestInfo request) {
// 1. Check for rate limiting first
if (isRateLimited(request.clientIP(), request.timestamp())) {
return new WAFDecision(true, WAFAction.RATE_LIMIT, 1.0, -1, "Rate limit exceeded");
}
// 2. Check for obvious, rule-based attack patterns
if (isObviousAttack(request)) {
return new WAFDecision(true, WAFAction.BLOCK, 1.0, -1, "Known attack pattern detected");
}
// 3. Use the Autoencoder for anomaly detection
double[] features = extractFeatures(request);
addToTrainingBuffer(features); // Add to buffer for future training
INDArray input = Nd4j.create(features).reshape(1, inputSize);
INDArray reconstructed = autoEncoder.output(input);
// Mean squared reconstruction error: average of (input - reconstructed)^2
INDArray diff = input.sub(reconstructed);
double error = diff.mul(diff).sumNumber().doubleValue() / inputSize;
boolean isAnomaly = error > anomalyThreshold;
double confidence = Math.min(error / anomalyThreshold, 2.0); // Cap confidence at 2.0
if (isAnomaly) {
String reason = "Anomaly detected. Error: " + String.format("%.4f", error);
// Block if confidence is very high, otherwise just log it
WAFAction action = (confidence > 1.5) ? WAFAction.BLOCK : WAFAction.LOG_ONLY;
return new WAFDecision(true, action, confidence, error, reason);
}
// 4. If nothing is suspicious, allow the request
return new WAFDecision(false, WAFAction.ALLOW, 0.0, error, "Request allowed");
}
private double[] extractFeatures(RestRequestInfo request) {
double[] features = new double[inputSize];
String fullUrl = request.path() + (request.query() != null ? "?" + request.query() : "");
Map<String, String> headers = request.headers();
// Basic features
features[0] = Math.min(fullUrl.length() / 250.0, 1.0);
features[1] = Math.min(countParams(request.query()) / 20.0, 1.0);
features[2] = hasSpecialChars(fullUrl) ? 1.0 : 0.0;
features[3] = hasSqlKeywords(fullUrl) ? 1.0 : 0.0;
features[4] = hasXssPatterns(fullUrl) ? 1.0 : 0.0;
features[5] = hasPathTraversal(fullUrl) ? 1.0 : 0.0;
features[6] = hasCommandInjection(fullUrl) ? 1.0 : 0.0;
// Method (one-hot encoding)
features[7] = request.method().equals("GET") ? 1.0 : 0.0;
features[8] = request.method().equals("POST") ? 1.0 : 0.0;
// Header features
features[12] = Math.min(headers.size() / 30.0, 1.0);
features[13] = hasSuspiciousUserAgent(headers.get("User-Agent")) ? 1.0 : 0.0;
// Content features
features[14] = Math.min(request.contentLength() / 10000.0, 1.0);
// Advanced features
features[22] = getEntropyScore(fullUrl);
features[23] = hasEncodedPayload(fullUrl) ? 1.0 : 0.0;
features[24] = Math.min(getPathDepth(request.path()) / 10.0, 1.0);
return features;
}
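Every feature is squeezed into the range [0, 1] before it reaches the model, either by dividing by an assumed maximum and capping the result, or by mapping a boolean to 0 or 1. As a rough illustration, this standalone sketch reproduces the URL-length scaling and the Shannon-entropy score for two sample URLs. The class name `FeatureSketch` and the sample URLs are assumptions; the divisors (250 characters, 8 bits) are taken from the code in this tutorial.

```java
import java.util.HashMap;
import java.util.Map;

public class FeatureSketch {

    // Scale a raw value into [0, 1] by dividing by an assumed maximum and capping
    static double scale(double value, double max) {
        return Math.min(value / max, 1.0);
    }

    // Shannon entropy in bits per character, normalized by 8 bits and capped at 1.0
    static double entropyScore(String text) {
        if (text == null || text.isEmpty()) {
            return 0.0;
        }
        Map<Character, Integer> freq = new HashMap<>();
        for (char c : text.toCharArray()) {
            freq.merge(c, 1, Integer::sum);
        }
        double entropy = 0.0;
        for (int count : freq.values()) {
            double p = (double) count / text.length();
            entropy -= p * (Math.log(p) / Math.log(2));
        }
        return Math.min(entropy / 8.0, 1.0);
    }

    public static void main(String[] args) {
        String plain = "/hello?name=Quarkus";
        String noisy = "/hello?name=%3Cscript%3Ealert(1)%3C%2Fscript%3E";
        System.out.printf("length:  %.3f vs %.3f%n", scale(plain.length(), 250.0), scale(noisy.length(), 250.0));
        System.out.printf("entropy: %.3f vs %.3f%n", entropyScore(plain), entropyScore(noisy));
    }
}
```

A string made of a single repeated character scores 0, while four equally frequent characters carry 2 bits each and score 0.25.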
Finally, add the helper methods for pattern matching, rate limiting, training, and logging.
private boolean isObviousAttack(RestRequestInfo request) {
String fullUrl = request.path() + (request.query() != null ? "?" + request.query() : "");
String userAgent = request.headers().get("User-Agent");
return hasSqlKeywords(fullUrl) || hasXssPatterns(fullUrl) || hasPathTraversal(fullUrl)
|| hasCommandInjection(fullUrl) || hasSuspiciousUserAgent(userAgent);
}
private boolean isRateLimited(String clientIP, long timestamp) {
// Fixed-window counter: lastRequestTime holds the start of the client's current window
AtomicLong windowStart = lastRequestTime.computeIfAbsent(clientIP, k -> new AtomicLong(timestamp));
AtomicInteger count = requestCounts.computeIfAbsent(clientIP, k -> new AtomicInteger(0));
if (timestamp - windowStart.get() > rateLimitWindow) {
// The window has expired: start a new one and count this request as its first
windowStart.set(timestamp);
count.set(1);
return false;
}
return count.incrementAndGet() > maxRequestsPerWindow;
}
private void performIncrementalTraining() {
if (trainingBuffer.size() < 100) {
logger.info("Skipping training cycle, not enough data.");
return; // Don't train on too few samples
}
try {
double[][] data = trainingBuffer.toArray(new double[0][]);
INDArray trainingData = Nd4j.create(data);
DataSet dataSet = new DataSet(trainingData, trainingData);
logger.info("Starting incremental training cycle #" + trainingCycle.incrementAndGet() + " with "
+ data.length + " samples.");
autoEncoder.fit(dataSet);
logger.info("Training complete. Current anomaly threshold: " + anomalyThreshold);
} catch (Exception e) {
logger.log(Level.WARNING, "Error during incremental training", e);
}
}
// Utility methods for aborting, logging, cleanup, and pattern matching
private void abortWithSecurityResponse(ContainerRequestContext context, Response.Status status, String code,
String message) {
String jsonPayload = String.format("{\"error\":\"%s\",\"code\":\"%s\"}", message, code);
Response response = Response.status(status)
.entity(jsonPayload)
.type("application/json")
.build();
context.abortWith(response);
}
private void logDecision(RestRequestInfo request, WAFDecision decision, long processingTime) {
WAFLogEntry logEntry = new WAFLogEntry(System.currentTimeMillis(), request.clientIP(), request.method(),
request.path(), decision.action(), decision.reconstructionError(), processingTime, decision.reason());
if (auditLog.size() >= maxAuditLogSize)
auditLog.poll();
auditLog.offer(logEntry);
if (decision.action() != WAFAction.ALLOW) {
logger.info("WAF Decision: " + logEntry);
}
}
private void addToTrainingBuffer(double[] features) {
if (trainingBuffer.size() >= bufferSize)
trainingBuffer.poll();
trainingBuffer.offer(features);
}
private void cleanupOldData() {
long oneHourAgo = System.currentTimeMillis() - 3_600_000;
lastRequestTime.entrySet().removeIf(entry -> entry.getValue().get() < oneHourAgo);
requestCounts.keySet().retainAll(lastRequestTime.keySet());
logger.info("Cleanup complete. Tracking " + requestCounts.size() + " active clients.");
}
// --- Pattern Matching Helpers ---
private boolean hasSpecialChars(String s) {
return s != null && s.matches(".*[<>'\"%;()&+].*");
}
private boolean hasSqlKeywords(String s) {
return s != null && s.toLowerCase().matches(".*(select|insert|union|script|exec|drop).*");
}
private boolean hasXssPatterns(String s) {
return s != null && s.toLowerCase().matches(".*(<script|javascript:|onload=|onerror=).*");
}
private boolean hasPathTraversal(String s) {
return s != null && s.contains("../");
}
private boolean hasCommandInjection(String s) {
return s != null && s.toLowerCase().matches(".*(;|\\|\\|?|&&)\\s*(cat|ls|pwd|whoami).*");
}
private boolean hasSuspiciousUserAgent(String ua) {
return ua != null && ua.toLowerCase().matches(".*(sqlmap|nikto|nessus|burp).*");
}
private boolean hasEncodedPayload(String s) {
return s != null && s.matches(".*(%[0-9a-fA-F]{2}|\\+).*");
}
private int countParams(String q) {
return q == null ? 0 : q.split("&").length;
}
private double getPathDepth(String p) {
return p == null ? 0 : p.chars().filter(ch -> ch == '/').count();
}
private double getEntropyScore(String text) {
if (text == null || text.isEmpty())
return 0.0;
Map<Character, Integer> freq = new HashMap<>();
text.chars().forEach(c -> freq.put((char) c, freq.getOrDefault((char) c, 0) + 1));
double entropy = freq.values().stream()
.mapToDouble(count -> (double) count / text.length())
.map(p -> -p * (Math.log(p) / Math.log(2)))
.sum();
return Math.min(entropy / 8.0, 1.0); // Normalize
}
// Public accessors for management endpoint
public Map<String, Object> getStatistics() {
Map<String, Object> stats = new HashMap<>();
stats.put("trackedClients", requestCounts.size());
stats.put("trainingBufferSize", trainingBuffer.size());
stats.put("auditLogSize", auditLog.size());
stats.put("anomalyThreshold", anomalyThreshold);
stats.put("trainingCycles", trainingCycle.get());
return stats;
}
public void reportFalsePositive() {
this.anomalyThreshold *= 1.05; // Slightly increase threshold
logger.info("False positive reported. New threshold: " + this.anomalyThreshold);
}
}
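One thing to watch with substring-based keyword checks like hasSqlKeywords is that they also fire on harmless URLs that merely contain a keyword, such as /products/selection. The sketch below compares the tutorial's keyword list with a stricter word-boundary variant. The class name, the `STRICT` pattern, and the sample URLs are my own illustrative assumptions, not part of the project, and a word-boundary check is still no substitute for a maintained rule set.

```java
import java.util.regex.Pattern;

public class KeywordPatternSketch {

    // Same keyword list as hasSqlKeywords in the filter: matches anywhere in the string
    static final Pattern LOOSE = Pattern.compile("(select|insert|union|script|exec|drop)");

    // Hypothetical stricter variant: keywords must stand alone as whole words
    static final Pattern STRICT = Pattern.compile("\\b(select|insert|union|script|exec|drop)\\b");

    static boolean looseMatch(String url) {
        return url != null && LOOSE.matcher(url.toLowerCase()).find();
    }

    static boolean strictMatch(String url) {
        return url != null && STRICT.matcher(url.toLowerCase()).find();
    }

    public static void main(String[] args) {
        String[] samples = {
            "/products/selection",
            "/ui/dropdown?open=true",
            "/hello?name=tester' union select 1--"
        };
        for (String s : samples) {
            System.out.println(s + " -> loose=" + looseMatch(s) + ", strict=" + strictMatch(s));
        }
    }
}
```

The loose pattern flags all three samples, while the strict one flags only the injection attempt. Since isObviousAttack blocks outright on a keyword hit, the difference decides whether a legitimate /products/selection page is reachable at all.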
Add Management Endpoints
Create src/main/java/org/example/waf/WAFManagementResource.java:
package org.example.waf;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.Response;
@Path("/management/waf")
@ApplicationScoped
public class WAFManagementResource {
@Inject
SelfLearningWAFFilter wafFilter;
@GET
@Path("/stats")
@Produces("application/json")
public Response getStatistics() {
return Response.ok(wafFilter.getStatistics()).build();
}
@POST
@Path("/false-positive")
public Response reportFalsePositive() {
wafFilter.reportFalsePositive();
return Response.ok("Threshold adjusted.").build();
}
}
This gives you a way to tune the threshold or monitor activity without restarting the app.
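Because each false-positive report multiplies the threshold by 1.05, repeated reports compound. A quick back-of-the-envelope sketch (the class and method names are hypothetical) shows how far the default of 0.15 drifts:

```java
public class ThresholdDriftSketch {

    // Threshold after a number of false-positive reports, each multiplying it by 1.05
    static double thresholdAfter(double initial, int reports) {
        double threshold = initial;
        for (int i = 0; i < reports; i++) {
            threshold *= 1.05;
        }
        return threshold;
    }

    public static void main(String[] args) {
        for (int reports : new int[] {1, 5, 10, 15}) {
            System.out.printf("%2d reports -> threshold %.4f%n", reports, thresholdAfter(0.15, reports));
        }
    }
}
```

After about 15 reports the threshold has roughly doubled, and nothing in the endpoint caps it or lowers it again. If you expose this beyond a demo, it is probably worth bounding the value and restricting who can call the endpoint.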
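Tweak the GreetingResource
Replace the generated src/main/java/org/example/GreetingResource.java with a version that accepts a query parameter and arbitrary sub-paths, so there is something to attack: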
package org.example;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.core.MediaType;
@Path("/hello")
public class GreetingResource {
@GET
@Produces(MediaType.TEXT_PLAIN)
public String hello(@QueryParam("name") String name) {
return "Hello from Quarkus REST " + name;
}
@GET
@Path("/{param: .+}")
@Produces(MediaType.TEXT_PLAIN)
public String helloWithPath(@PathParam("param") String param) {
return "Hello from Quarkus REST " + param;
}
}
Try It Out
Start your app:
./mvnw quarkus:dev
Test with Normal Requests
curl http://localhost:8080/hello
curl "http://localhost:8080/hello?name=Quarkus"
Simulate Attacks
SQL Injection:
curl "http://localhost:8080/hello?name=tester'%20union%20select%201--"
XSS:
curl "http://localhost:8080/hello?name=<script>alert('XSS')</script>"
Encoded Path Traversal:
curl "http://localhost:8080/hello/%2e%2e%2f%2e%2e%2fetc%2fpasswd"
Rate Limiting:
for i in {1..110}; do curl -s -o /dev/null http://localhost:8080/hello; done
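If you want to sanity-check what that loop should produce, here is a standalone simulation of a fixed-window counter using the tutorial's limits (100 requests per 60-second window). It is a simplified single-client sketch with hypothetical names, not the filter's actual code; timestamps are passed in explicitly so the behavior is deterministic.

```java
public class FixedWindowSketch {

    private final long windowMillis;
    private final int maxRequests;
    private long windowStart = -1; // -1 means no request seen yet
    private int count = 0;

    FixedWindowSketch(long windowMillis, int maxRequests) {
        this.windowMillis = windowMillis;
        this.maxRequests = maxRequests;
    }

    // Returns true when the request at time 'now' exceeds the limit for its window
    synchronized boolean isRateLimited(long now) {
        if (windowStart < 0 || now - windowStart > windowMillis) {
            windowStart = now; // Start a new window with this request as its first
            count = 1;
            return false;
        }
        return ++count > maxRequests;
    }

    // Sends 'requests' requests 10 ms apart starting at 'start' and counts the rejections
    static int countLimited(FixedWindowSketch limiter, long start, int requests) {
        int limited = 0;
        for (int i = 0; i < requests; i++) {
            if (limiter.isRateLimited(start + i * 10L)) {
                limited++;
            }
        }
        return limited;
    }

    public static void main(String[] args) {
        FixedWindowSketch limiter = new FixedWindowSketch(60_000, 100);
        System.out.println(countLimited(limiter, 0, 110));       // prints 10
        System.out.println(countLimited(limiter, 120_000, 110)); // prints 10 again: a new window started
    }
}
```

In other words, requests 101 to 110 of the curl loop should come back with HTTP 429. Adding -w "%{http_code}\n" to the curl command prints the status of each request so you can confirm it.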
Monitor Stats
curl http://localhost:8080/management/waf/stats | jq
You’ll see live updates on the training buffer, number of clients, audit logs, and training cycles.
What’s Next?
You’ve built a PoC of a WAF core. Now expand it:
Persist your model using ModelSerializer.writeModel() and reload it on startup
Add body inspection for POST/PUT payloads
Integrate with ELK or Prometheus for visualization
Use Redis or Kafka to share traffic stats across a cluster
Create a feedback loop where users can report false positives via the frontend
This isn't just anomaly detection; it's a showcase for adaptive, ML-powered security baked into your Quarkus microservices.
You're no longer blocking what attackers did yesterday. You're defending against what they might try tomorrow.