Optimizing data pipelines for large-scale geospatial machine learning tasks


The Challenge

The task: train a machine learning model on a single coordinate’s weather data, drawn from a 162TB dataset distributed across 44 zarr batches. Each batch contained ~3,930 features across multiple time dimensions.

Goal: Extract, process, and train without running out of memory or time.

Constraints:

  • 96 CPU cores available
  • 768GB RAM
  • Data stored in Google Cloud Storage (GCS)
  • Local SSD for caching

Key Learnings

1. The Hidden Zarr Compression Bottleneck 🐌

The Problem

Our initial pipeline took 4+ hours to load data, despite having 96 CPU cores available.

The Root Cause

Zarr files compressed with blosc(nthreads=1) force single-threaded decompression. No amount of dask configuration or parallel workers can fix this.

What We Learned

  • ❌ Zarr compression codec settings are permanent - they’re baked into the files
  • ❌ Single-threaded decompression = 1 core working while 95 cores sit idle
  • ❌ Loading 10 file_ids took 10-15 minutes per batch, regardless of available cores
  • ✅ This is the #1 performance bottleneck in our pipeline

The Fix (for future datasets)

When creating zarr files, use multi-threaded compression:

blosc(cname='zstd', clevel=3, shuffle=2, nthreads=8)

Expected improvement: 8× speedup on decompression
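
For example, with the numcodecs + zarr v2 stack (an assumption about the writer tooling; adapt to whatever created your stores), the codec is chosen at write time roughly like this. Note that in numcodecs the blosc thread count is a global setting rather than a codec argument:

import numcodecs
import numpy as np
import zarr

# Let blosc compress/decompress with multiple threads (global numcodecs setting)
numcodecs.blosc.set_nthreads(8)

# shuffle=2 corresponds to Blosc.BITSHUFFLE
compressor = numcodecs.Blosc(cname='zstd', clevel=3, shuffle=numcodecs.Blosc.BITSHUFFLE)

data = np.random.rand(4000, 4000)
z = zarr.open('example.zarr', mode='w', shape=data.shape,
              chunks=(500, 500), dtype=data.dtype, compressor=compressor)
z[:] = data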


2. Dask Configuration Doesn’t Always Help ⚙️

What We Tried

import dask

dask.config.set(scheduler='threads', num_workers=32)
coord_data.compute(scheduler='threads', num_workers=32)

What Happened

Nothing. Still single-threaded. CPU usage remained at 100% (1 core).

The Learning

  • Dask parallelizes computation graphs, not I/O-bound codec operations
  • If the bottleneck is in native C libraries (like blosc), dask can’t help
  • Always profile to find the actual bottleneck before optimizing
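
One quick way to do that profiling with dask’s built-in diagnostics (a sketch, assuming the threaded scheduler and the lazy coord_data array from above; ResourceProfiler samples CPU and memory while the graph runs):

import dask
from dask.diagnostics import ResourceProfiler

with dask.config.set(scheduler='threads', num_workers=32), ResourceProfiler(dt=1) as rprof:
    result = coord_data.compute()

# If decompression is the single-threaded bottleneck, CPU hovers near 100% (one core)
print(max(sample.cpu for sample in rprof.results))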

When Dask DOES Help

  • ✅ Loading multiple files in parallel
  • ✅ CPU-intensive transformations on lazy arrays
  • ✅ Rechunking operations
  • ✅ Distributed computing across multiple machines

When Dask DOESN’T Help

  • ❌ Single-threaded codec decompression
  • ❌ Sequential I/O operations
  • ❌ Operations already optimized in C/Fortran

3. Batch Size Matters More Than You Think 📦

The Evolution

| Approach   | Batches | Files per Batch | Result            |
|------------|---------|-----------------|-------------------|
| Initial    | 23      | 3               | Too much overhead |
| Optimized  | 7       | 10              | Sweet spot        |
| Aggressive | 1       | 69              | Risk of OOM       |

The Impact

  • ~70% fewer .compute() calls (23 → 7 batches)
  • Each call has overhead (lazy graph construction, memory allocation)
  • Total time reduced from ~3 hours to ~1.5 hours

The Rule of Thumb

  • Too small: Death by a thousand cuts (overhead dominates)
  • Too large: RAM exhaustion and crashes
  • Sweet spot: Size batches to use 10-20% of available RAM

How to Calculate Optimal Batch Size

1. Check single file size: ~15GB per file_id
2. Available RAM: 768GB
3. Target usage: ~20% = 150GB
4. Batch size: 150GB / 15GB = 10 files per batch
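
In code, the same back-of-the-envelope arithmetic (numbers from this pipeline; substitute your own):

file_size_gb = 15        # ~15GB per file_id
total_ram_gb = 768
target_fraction = 0.20   # spend ~20% of RAM on each batch

batch_budget_gb = total_ram_gb * target_fraction        # ~150GB
files_per_batch = int(batch_budget_gb // file_size_gb)  # 10
print(f"{files_per_batch} files per batch (~{files_per_batch * file_size_gb}GB)")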

4. Parallelize the Right Thing 🎯

The Breakthrough

Separating I/O from CPU work

Before: Everything Serial

Load data (single-threaded, 600s)
  ↓
Extract features (serial loop, 600s)
  ↓
Total: ~1,200 seconds/batch

After: Strategic Parallelization

Load data (single-threaded, 600s)  ← Can't parallelize (codec limitation)
  ↓
Extract features (32 workers, 15s)  ← Highly parallelized
  ↓
Total: ~615 seconds/batch

The Learning

I/O-bound operations:

  • ❌ Can’t parallelize: zarr decompression, disk reads, network I/O
  • Accept the limitation and optimize elsewhere

CPU-bound operations:

  • ✅ Highly parallelizable: feature extraction, transformations, calculations
  • Use ThreadPoolExecutor for already-loaded data
  • Scales linearly with cores (up to a point)

Time Saved

40% reduction in total time by parallelizing just the extraction step
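
A minimal sketch of that split, assuming the batch is already loaded into memory as an xarray Dataset and extract_features() is a stand-in for the per-file_id, CPU-bound transform:

from concurrent.futures import ThreadPoolExecutor

def extract_batch(ds, file_ids, n_workers=32):
    # The slow part (decompression) has already happened; these tasks only do
    # CPU work on in-memory arrays, and numpy releases the GIL for most of it.
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        return list(pool.map(lambda fid: extract_features(ds.sel(file_id=fid)), file_ids))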


5. Memory-Efficient Processing Strategies 💾

Three Approaches We Tested

❌ Approach 1: Load All at Once
Load entire 336GB → Extract features → Train

Pros:

  • Fastest IF you have enough RAM
  • Simplest code

Cons:

  • High risk of OOM (Out of Memory) errors
  • RAM usage spikes during load
  • No fault tolerance

❌ Approach 2: Tiny Batches
23 batches of 3 file_ids each

Pros:

  • Very safe on memory
  • Fault tolerance (can resume)

Cons:

  • Death by overhead (23× the setup cost)
  • Much slower overall

✅ Approach 3: Balanced Batches
7 batches of 10 file_ids each

Pros:

  • Peak RAM: ~220GB (safe on 768GB machine)
  • Minimal overhead (only 7 batches)
  • Good balance of speed and safety

Cons:

  • Slightly more complex logic

Memory Usage Progression

| Stage              | RAM Used | Notes                    |
|--------------------|----------|--------------------------|
| Initial            | 10GB     | Script loaded, data lazy |
| Batch 1 loading    | 85GB     | Decompressing zarr       |
| Batch 1 extraction | 75GB     | Features being written   |
| After cleanup      | 50GB     | Old batch freed          |
| Batch 2 loading    | 85GB     | Process repeats          |
| Final matrices     | 220GB    | All features in memory   |

The Takeaway

Profile your RAM usage and find the largest batch size that fits comfortably. Leave 20-30% headroom for OS and other processes.
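
A minimal sketch of the batch loop with explicit cleanup between batches (batches, load_batch and extract_features are stand-ins for the project-specific pieces):

import gc
import os
import psutil

def rss_gb():
    return psutil.Process(os.getpid()).memory_info().rss / 1024**3

all_features = []
for i, batch_ids in enumerate(batches):   # e.g. 7 batches of 10 file_ids each
    ds = load_batch(batch_ids)            # decompress zarr into memory (the ~85GB spike)
    all_features.append(extract_features(ds))
    del ds                                # free the raw batch before loading the next one
    gc.collect()
    print(f"Batch {i}: RSS {rss_gb():.0f}GB", flush=True)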


6. ProcessPoolExecutor Pitfalls 🕳️

The Mystery

Worker processes would hang indefinitely with no error message. No logs, no timeout, no crash - just eternal silence.

The Investigation

Batch 5: Starting...
Batch 5: Opened zarr stores...
Batch 5: Writing to cache...
[HANGS FOREVER - NO OUTPUT]

The Culprit

  • Zarr’s to_zarr() with compute=True in subprocess
  • Some internal threading conflict between zarr, blosc, and multiprocessing
  • No timeout = no way to detect the hang

The Solution

Always set timeouts:

future.result(timeout=1800)  # 30 minutes max per batch

Wait on each future with a per-batch timeout (note: iterating with as_completed() and no timeout blocks forever on a hung worker, before result() is ever called):

from concurrent.futures import TimeoutError

# futures maps each submitted Future to its batch index, e.g.
# futures = {executor.submit(process_batch, b): idx for idx, b in enumerate(batches)}
for future, idx in futures.items():
    try:
        result = future.result(timeout=TIMEOUT)  # waits at most TIMEOUT seconds for this batch
        # Process result
    except TimeoutError:
        # Handle timeout: log it and move on instead of hanging forever
        print(f"Batch {idx} timed out", flush=True)

For I/O heavy tasks, consider alternatives:

  • ThreadPoolExecutor (if GIL isn’t an issue)
  • Sequential processing with good progress bars
  • Async I/O with asyncio

The Learning

Multiprocessing + complex I/O libraries = potential deadlocks. Always add timeouts and monitoring. Better to fail fast than hang forever.
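
Putting the pieces together, a minimal sketch of the submit-with-timeout pattern (process_batch is a stand-in for the real per-batch work):

from concurrent.futures import ProcessPoolExecutor, TimeoutError

TIMEOUT = 1800  # 30 minutes max per batch

def run_batches(batches, max_workers=4):
    results = {}
    pool = ProcessPoolExecutor(max_workers=max_workers)
    futures = {pool.submit(process_batch, b): idx for idx, b in enumerate(batches)}
    for future, idx in futures.items():
        try:
            results[idx] = future.result(timeout=TIMEOUT)  # per-batch bound
        except TimeoutError:
            # Detects the hang; the stuck worker itself still has to be dealt with
            print(f"Batch {idx} timed out", flush=True)
    # Don't block forever on hung workers at shutdown (cancel_futures needs Python 3.9+)
    pool.shutdown(wait=False, cancel_futures=True)
    return results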


7. Caching is Your Best Friend 💰

The Strategy

1. Download each batch once to local SSD
2. Cache with coordinate-specific directory names
3. Check cache before re-downloading
4. Validate cached files before trusting them

The Implementation

import os

CACHE_DIR = f"/local_ssd/batch_cache_uk_{lat}_{lon}"  # coordinate-specific cache directory

def get_batch(i):
    cache_path = f"{CACHE_DIR}/batch_{i}.zarr"

    # Check cache first
    if os.path.exists(cache_path):
        return load_from_cache(cache_path)

    # Download if missing, then cache to local SSD for next time
    # (load_from_cache / download_batch / save_to_cache are project-specific helpers)
    ds = download_batch(i)
    save_to_cache(ds, cache_path)
    return ds

The Payoff

| Run Type             | Time       | Notes                         |
|----------------------|------------|-------------------------------|
| First run            | ~2 hours   | Download all batches from GCS |
| Second run           | ~4 minutes | Validate cache only           |
| Nth run              | ~4 minutes | Same coordinate               |
| Different coordinate | ~2 hours   | New cache directory           |

Speed improvement: 30× faster for subsequent runs

Best Practices

  1. Include parameters in cache name
    batch_cache_uk_51.5_0.0  ← Good (coordinate-specific)
    batch_cache              ← Bad (conflicts)
    
  2. Validate cache files
    try:
        ds = xr.open_zarr(cache_path)
        if len(ds.data_vars) == 0:
            return False  # Corrupted
    except Exception:
        return False  # Invalid
    
  3. Use local SSD if available
    • NVMe SSD: 3-7 GB/s read speed
    • Network storage: 100-500 MB/s
    • 10-70× faster for local SSD
  4. Implement cache cleanup strategy
    • Set cache size limits
    • Use LRU (Least Recently Used) eviction
    • Clean up on successful completion (optional)

8. Data Integrity Issues 🔍

The Surprise

After concatenating 22 batches, we expected 95 unique file_ids but found only 69.

The Diagnosis

Batch 0: file_ids [0, 1, 2, 3]
Batch 1: file_ids [0, 1, 2, 3]  ← DUPLICATE!
Batch 2: file_ids [4, 5, 6]
...

Result: 26 duplicate file_ids across batches

The Impact

  • Concatenating along file_id created duplicate rows
  • Silent data corruption (no error raised)
  • Models would train on inflated datasets
  • Results would be biased toward duplicated time periods

The Fix

import numpy as np

# 1. Detect duplicates
unique_file_ids, unique_indices = np.unique(
    dataset.file_id.values,
    return_index=True
)

# 2. Count them
n_duplicates = len(dataset.file_id) - len(unique_file_ids)

# 3. Remove duplicates (keep first occurrence)
if n_duplicates > 0:
    unique_indices = np.sort(unique_indices)
    dataset = dataset.isel(file_id=unique_indices)

Prevention Strategies

  1. Validate during data creation
    • Ensure batches have non-overlapping file_ids
    • Add assertions in batch creation code
  2. Check after concatenation
    • Always validate dimension uniqueness
    • Log warnings for duplicates
  3. Add metadata
    • Include file_id ranges in batch metadata
    • Document expected ranges
  4. Data source alignment
    • Verify both data sources have same file_id scheme
    • Check if paths should be merged vs concatenated

The Lesson

Never trust data integrity implicitly. Always validate dimension uniqueness after concatenation, especially with multiple data sources.
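
A tiny check worth running after every concat (assumes an xarray Dataset with a file_id dimension):

import numpy as np

def assert_unique_file_ids(ds):
    ids = ds.file_id.values
    n_dupes = len(ids) - len(np.unique(ids))
    assert n_dupes == 0, f"{n_dupes} duplicate file_ids found after concatenation"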


9. When More Cores Don’t Help 🚫

The Question

“We have 96 cores. Why not use 64 or 96 workers instead of 32?”

The Reality Check

Time breakdown per batch:

  • Load time: 600 seconds (98% of time, single-threaded)
  • Extract time: 15 seconds (2% of time, 32 workers)

Performance Analysis

| Workers    | Extract Time | Improvement | Verdict      |
|------------|--------------|-------------|--------------|
| 1 (serial) | 600s         | baseline    | -            |
| 8          | 75s          | 525s saved  | Significant  |
| 16         | 38s          | 562s saved  | Good         |
| 32         | 15s          | 585s saved  | Excellent    |
| 64         | 10s          | 590s saved  | Minimal gain |
| 96         | 8s           | 592s saved  | Minimal gain |

Increasing Workers to 96

Time saved: 7 seconds per batch
Total saved (7 batches): 49 seconds out of ~5,000 seconds total
Improvement: <1%

The Hidden Costs

  • Thread creation overhead
  • Context switching
  • Memory contention
  • GIL (Global Interpreter Lock) contention
  • Diminishing returns after ~CPU_count/2

The Principle

Optimize the bottleneck, not the fast parts.

Where to focus:

  1. ✅ Fix the 600s load time (98% of time)
  2. ❌ Not the 15s extract time (2% of time)

Amdahl’s Law in Practice

If 98% of time is serial (load):
Maximum speedup with infinite parallel cores = 1.02×

If you could parallelize the load time:
Maximum speedup = 50×
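
The arithmetic behind those two numbers is Amdahl’s law. With serial fraction s and N parallel workers:

Speedup(N) = 1 / (s + (1 − s)/N)  →  1/s  as N → ∞

So s = 0.98 caps the speedup at 1/0.98 ≈ 1.02×, while s = 0.02 caps it at 1/0.02 = 50×.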

Lesson: Attack the bottleneck, not everything else.


10. Diagnostic Tools Are Essential 🔧

The Toolkit

1. top - Real-time System Monitor
top -u username

What to watch:

  • %CPU: Is parallelism working?
    • 100% = 1 core active
    • 3200% = 32 cores active
    • 9600% = 96 cores active (theoretical max)
  • RES: Actual RAM usage (watch for memory leaks)
  • TIME+: How long the process has been running
  • S (State): R=Running, S=Sleeping, D=Disk wait
2. htop - Interactive Process Viewer
htop

Benefits:

  • See all cores individually (visual bars)
  • Identify if work is distributed or concentrated
  • Sort by CPU, memory, time
  • Tree view for process hierarchy
3. Strategic print() Statements
import time

start = time.time()
print(f"Starting batch {i}...", flush=True)

# Do work...

elapsed = time.time() - start
print(f"Completed in {elapsed:.1f}s", flush=True)

Key additions:

  • flush=True - Force immediate output (critical for long operations)
  • Timestamps for each operation
  • Memory usage after major operations
  • Progress indicators for long loops
4. Memory Profiling
import psutil
import os

def log_memory():
    process = psutil.Process(os.getpid())
    mem_gb = process.memory_info().rss / (1024**3)
    print(f"Memory: {mem_gb:.1f}GB")

# Call at key points
log_memory()
5. Python Profilers

For CPU profiling:

python -m cProfile -o output.prof script.py
python -m pstats output.prof

For memory profiling:

pip install memory_profiler
python -m memory_profiler script.py
6. Zarr Store Inspector
import zarr

# Check compression settings
z = zarr.open('path/to/store.zarr')
print(z.info)  # Shows codec, chunks, dtype

# Check compressor
for key, arr in z.arrays():  # zarr v2: arrays() yields (name, array) pairs
    print(f"{key}: {arr.compressor}")

The Learning

You can’t optimize what you can’t measure.

Best practices:

  1. Instrument early and often
  2. Log timestamps at every major operation
  3. Monitor resource usage in real-time
  4. Profile before optimizing
  5. Measure improvements quantitatively

11. Throughput Beats Per-Operation Speed 🎯

The Dilemma

You have 96 cores. Should you:

  • Option A: 8 workers × 12 threads = 96 total threads
  • Option B: 16 workers × 6 threads = 96 total threads

Most people choose A, thinking “more threads per operation = faster operations.”

The Reality Check

What actually takes time:

| Phase             | Time per Batch | Affected by Threads?       |
|-------------------|----------------|----------------------------|
| Network download  | 500-700s       | ❌ No                      |
| Decompress source | 600-900s       | ❌ No (source codec fixed) |
| Compress output   | 50-100s        | ✅ Yes                     |
| Total             | ~1,500s        | Only ~5%                   |

You’re only optimizing 5% of the operation!

The Comparison

Option A: 8 workers × 12 threads

Individual compression: 50s (fast)
Total batches to process: 22
Can process in parallel: 8 at a time
Rounds needed: 22 ÷ 8 = 2.75 rounds
Total time: 1,500s × 2.75 = 4,125s (69 minutes)

Option B: 16 workers × 6 threads

Individual compression: 75s (1.5× slower)
Total batches to process: 22
Can process in parallel: 16 at a time
Rounds needed: 22 ÷ 16 = 1.4 rounds
Total time: 1,550s × 1.4 = 2,131s (36 minutes)

Result: Option B is 48% faster overall, despite slower individual compression!

The Principle

Maximize throughput (operations/time), not operation speed.

Throughput Formula:

Throughput = (Number of Workers) / (Time per Operation)

Option A: 8 / 1,500s = 0.0053 batches/second
Option B: 16 / 1,550s = 0.0103 batches/second

Option B has 94% higher throughput!
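
The same arithmetic as a two-line sanity check (fractional "rounds" assume the queue keeps every worker continuously fed):

def total_time(n_batches, n_workers, time_per_batch):
    return time_per_batch * (n_batches / n_workers)

print(total_time(22, 8, 1500))   # ~4,125s - Option A
print(total_time(22, 16, 1550))  # ~2,131s - Option B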

The CPU Utilization View

Option A: 8 workers × 12 threads

Worker utilization over time:
████████░░░░░░░░░░░░░░░░░░░░  (8/16 potential workers active)
                               (~50% worker utilization)
Cores active: ~96 (but only 8 batches being processed)

Option B: 16 workers × 6 threads

Worker utilization over time:
████████████████████████████  (16/16 potential workers active)
                               (100% worker utilization)
Cores active: ~96 (16 batches being processed)

Even with the same number of active cores, more workers = better throughput.

When This Applies

This principle works when:

  • ✅ You have a queue of independent operations
  • ✅ The bottleneck is NOT the parallelized part (only 5% of time)
  • ✅ You have excess capacity (cores/RAM)

This doesn’t apply when:

  • ❌ You have a single large operation (no queue)
  • ❌ The parallelized part IS the bottleneck (98% of time)
  • ❌ You’re at resource limits

The Takeaway

When you have a queue of operations with a fixed bottleneck:

Don’t optimize: the 5% that’s already fast.
Do optimize: process more items in parallel.

The 3% slowdown per batch is completely dominated by the 2× speedup from parallel processing.

It’s not about how fast each batch completes - it’s about how many batches complete per hour.


12. The Dask Lazy Loading Trap 🪤

The Nightmare Scenario

You spend 90 minutes downloading 300GB of data from Google Cloud Storage, carefully cache it to local SSD, and breathe a sigh of relief. The next run loads in seconds. Perfect!

Then you start training and discover your data is 100% NaN. Hours of work, completely wasted.

What happened?

The Discovery

When we checked our cache directory, something was very wrong:

Downloaded: 22 batches × ~13GB each = 286GB expected
Actual cache size: 0.4GB (18MB per batch!)

18MB is metadata size, not data. Our “cache” contained pointers back to the source files, not the actual data.

The Root Cause: .compute() vs .load()

The bug was subtle:

coord_slice = merged.sel(lat=51.5, lon=0.0, method="nearest")
coord_slice_computed = coord_slice.compute()  # ❌ 
coord_slice_computed.to_zarr("cache.zarr")

We thought .compute() would force everything into memory as numpy arrays. It didn’t.

What .compute() does:

  • Executes the dask computation graph
  • Returns the computed result
  • But the result may still be wrapped in dask arrays

What .load() does:

  • Loads data into memory IN-PLACE
  • Forces conversion to numpy arrays
  • Modifies the dataset directly

The difference is critical for caching.

Why This Failed Silently

When xarray writes a dataset to zarr, it checks each array:

if isinstance(array, dask.array.Array):
    # Write as lazy dask (metadata + chunk references)
else:
    # Write as numpy (actual data)

Our “computed” dataset still contained dask arrays, so xarray wrote metadata and references to the original GCS files. The cache was just storing instructions like “load chunk 0 from gs://bucket/batch_000/…” instead of actual values.

The result: Every time we “loaded from cache,” we were actually fetching from GCS again, using the wrong coordinates, getting 100% NaN.

The Fix

One word change:

# WRONG:
coord_slice_computed = coord_slice.compute()

# RIGHT:
coord_slice_loaded = coord_slice.load()

That’s it. .load() guarantees numpy arrays.

The Validation Strategy

The real lesson: verify immediately after caching.

We added a simple check after each batch:

import numpy as np
import xarray as xr

# Check 1: Was the data actually materialized before writing?
# (Inspect the dataset you wrote, not the reopened one - open_zarr always
#  returns lazy arrays, so re-checking the reopened object proves nothing.)
assert isinstance(coord_slice_loaded['temperature'].data, np.ndarray)

# After writing cache, reopen it for the remaining checks
ds = xr.open_zarr(cache_path)

# Check 2: Is the file size reasonable?
size_gb = get_directory_size(cache_path) / (1024**3)
assert size_gb > 5.0, f"Cache suspiciously small: {size_gb}GB"

# Check 3: Can we access actual values?
sample = ds['temperature'].values[0, 0, 0, 0]
assert not np.isnan(sample), "Sample data is NaN"

With validation in place, we would have caught the bug on the first batch (in 4 minutes) instead of after all 22 batches (90 minutes).

Why This Is So Dangerous

Silent failure:

  • No errors or exceptions raised
  • Files are created successfully
  • Everything “looks” fine
  • Cache appears to work (loads quickly)
  • Only fails later when you check the actual data

Devastating consequences:

  • Hours of downloading wasted
  • Training pipeline fails mysteriously
  • Hard to diagnose (looks like data quality issue, not caching bug)
  • High restart cost (must redownload everything)

The smoking gun: If your cache is suspiciously small, you have lazy references, not data.

Real-World Impact

Time lost on this bug:

  • 90 minutes: Initial download
  • 60 minutes: “Loading” from cache (actually re-fetching from GCS)
  • 120 minutes: Debugging why data is 100% NaN
  • 90 minutes: Re-downloading with fix

Total: 6 hours lost to one word

The Documentation Gap

Xarray’s documentation says .compute() will “load lazy arrays into memory.” This is technically true but dangerously misleading. The arrays are computed, but not necessarily converted to numpy.

The docs don’t emphasize that for caching, you must use .load().

The Broader Pattern

This isn’t just about xarray. The lesson applies anywhere lazy evaluation meets caching:

Dask dataframes:

df.compute()   # returns an in-memory pandas DataFrame - a different object entirely
df.persist()   # returns a dask DataFrame with its results held in (cluster) memory

TensorFlow datasets:

ds.cache()     # the cache is only filled as the dataset is iterated; the first epoch pays full cost

Spark:

df.cache()     # lazy - nothing is stored until an action runs
df.persist()   # also lazy, despite the name

The pattern repeats: lazy evaluation frameworks have subtle differences between “compute” and “materialize.”

How to Avoid This

1. Use the right loading method

  • Xarray: .load() not .compute()
  • Understand your framework’s materialization semantics

2. Validate immediately

  • Check file sizes match expectations
  • Verify data types (numpy vs dask)
  • Sample actual values
  • Fail fast if validation fails

3. Test before scaling

  • Cache one batch first
  • Verify it thoroughly
  • Only then cache all batches

4. Trust but verify

  • Don’t assume success
  • Add assertions liberally
  • Monitor disk space during operations

The Takeaway

Lazy evaluation + caching = careful validation required

One character difference (.compute() vs .load()) cost us 6 hours. The bug was silent, the symptoms were delayed, and the diagnosis was difficult.

The antidote: Validate aggressively. Check your cache immediately after creation. If file sizes don’t match expectations, something is wrong.

In distributed data pipelines, the most expensive bugs are the ones that fail silently while looking like success.


Final Performance Summary

Metrics Comparison

| Metric             | Initial Naive   | Final Optimized     | Improvement          |
|--------------------|-----------------|---------------------|----------------------|
| Batch count        | 23              | 7                   | 3.3× fewer           |
| Files per batch    | 3               | 10                  | 3.3× larger          |
| Feature extraction | Serial (1 core) | 32 parallel workers | 32× faster           |
| Load time/batch    | 250s            | 750s                | - (limited by codec) |
| Extract time/batch | 600s            | 15s                 | 40× faster           |
| Total time/batch   | 850s            | 765s                | 1.1× faster          |
| Total batches      | 23 batches      | 7 batches           | -                    |
| Total pipeline     | ~5.5 hours      | ~1.5 hours          | 3.7× faster          |

Key Optimizations Impact

| Optimization         | Time Saved                 | Impact    |
|----------------------|----------------------------|-----------|
| Fewer batches (23→7) | ~1.5 hours                 | High      |
| Parallel extraction  | ~2.0 hours                 | Very High |
| Efficient caching    | Infinite (subsequent runs) | Critical  |
| Better batch sizing  | ~0.5 hours                 | Medium    |

What We Couldn’t Fix

  • Zarr decompression bottleneck: Limited by codec settings
    • Current: 600-900s per batch (single-threaded)
    • Potential: 75-110s per batch (with multi-threaded codec)
    • Requires re-creating zarr files

Takeaways for Your Next Data Pipeline

Top 12 Lessons

  1. 🔍 Profile first, optimize second
    • Find the actual bottleneck with top, htop, profilers
    • Don’t assume - measure
  2. 💾 Compression settings matter
    • Multi-threaded codecs save hours
    • Settings are permanent - choose wisely
    • Test decompression speed during prototyping
  3. 📦 Batch size is critical
    • Balance overhead vs memory
    • Target 10-20% of available RAM
    • Profile memory usage to find sweet spot
  4. 🎯 Parallelize the right operations
    • CPU-bound work: highly parallelizable
    • I/O-bound work: usually can’t parallelize
    • Focus on CPU-bound bottlenecks
  5. 💰 Cache aggressively
    • Local SSD storage is cheap
    • 30× speedup for repeated operations
    • Include parameters in cache keys
  6. 🔧 Add timeouts everywhere
    • Prevent silent hangs
    • Fail fast, not never
    • Always use future.result(timeout=X)
  7. ✅ Validate data integrity
    • Check for duplicates after merging
    • Verify dimension uniqueness
    • Add assertions liberally
  8. 📊 Monitor in real-time
    • Use top/htop to verify assumptions
    • Add logging with timestamps
    • Track memory usage
  9. 🚫 Know when NOT to parallelize
    • Some operations can’t be sped up
    • Thread overhead can make things slower
    • Amdahl’s Law: optimize the bottleneck
  10. ⚙️ Dask isn’t magic
    • Can’t fix codec-level bottlenecks
    • Only helps with computation graphs
    • Profile to see if it actually helps
  11. 🎯 Optimize for throughput, not speed
    • More parallel workers beats faster individual operations
    • When 95% of time is I/O, don’t over-optimize the 5%
    • Queue processing efficiency > per-item latency
  12. 🪤 Watch for lazy evaluation traps
    • Use .load() not .compute() for caching
    • Validate immediately after cache creation
    • Check file sizes and data types
    • Test on one item before scaling to all

The Bottom Line

Working with massive datasets requires understanding the full stack - from compression codecs to multiprocessing semantics to hardware capabilities.

Key Insights

The biggest performance gains often come from:

  • ✅ Fixing fundamental design choices (compression settings)
  • ✅ Strategic caching
  • ✅ Right-sizing batches
  • ✅ Validating assumptions immediately

Not from:

  • ❌ Adding more parallelism everywhere
  • ❌ Complex distributed computing frameworks
  • ❌ Throwing more hardware at the problem

The Golden Rule

The most expensive operation is the one you don’t need to do.

  • Cache it
  • Batch it smartly
  • Parallelize what actually benefits from parallelism
  • Fix the bottleneck, not everything else
  • Validate aggressively

Additional Resources

Tools Mentioned

  • top / htop - System monitoring
  • psutil - Python system utilities
  • memory_profiler - Memory usage profiling
  • cProfile - CPU profiling
  • zarr - Array storage library
  • dask - Parallel computing library
  • xarray - Multi-dimensional labeled arrays