Example 32
advanced
Module
Serialization
Neural Networks

Neural Network Module System

The Module base class is the foundation of Deepbox's neural network system. This example demonstrates building a custom Module subclass, enumerating its parameters, serializing it with stateDict()/loadStateDict(), switching between train and eval modes, freezing and unfreezing parameters, and composing layers with Sequential.

Deepbox Modules Used

  • deepbox/ndarray
  • deepbox/nn

What You Will Learn

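  • Extend Module and call registerModule() to register each sub-layer
  • .parameters() yields all learnable tensors recursively, including those of registered sub-layers
  • stateDict() and loadStateDict() save and restore parameters for serialization and checkpointing
  • .train() and .eval() toggle the mode, which affects layers such as Dropout and BatchNorm
  • .freezeParameters() stops gradient tracking, which is useful for transfer learning; .unfreezeParameters() re-enables it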
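
To make the registration idea above concrete, here is a minimal, self-contained sketch of how a module tree could enumerate parameters recursively and derive dotted state-dict keys. ToyModule, ToyLinear, ToyReLU, ToyNet, makeParam, registerParameter, and namedParameters are hypothetical stand-ins written for this illustration; they are not Deepbox's classes, and Deepbox's internals may differ.

```typescript
// Toy stand-ins, NOT Deepbox's classes: they only illustrate the registration pattern.
type Param = { shape: number[]; data: number[]; requiresGrad: boolean };

// Hypothetical helper: allocate a zero-filled parameter of the given shape.
function makeParam(shape: number[]): Param {
  const size = shape.reduce((a, b) => a * b, 1);
  return { shape, data: new Array<number>(size).fill(0), requiresGrad: true };
}

class ToyModule {
  private readonly params = new Map<string, Param>();
  private readonly children = new Map<string, ToyModule>();

  protected registerParameter(name: string, p: Param): void {
    this.params.set(name, p);
  }

  protected registerModule(name: string, m: ToyModule): void {
    this.children.set(name, m);
  }

  // Yield [dottedName, param] pairs: own parameters first,
  // then each child's parameters in registration order.
  *namedParameters(prefix = ""): Generator<[string, Param]> {
    for (const [name, p] of this.params) yield [prefix + name, p];
    for (const [name, child] of this.children) {
      yield* child.namedParameters(`${prefix}${name}.`);
    }
  }

  // Flatten the tree into a plain object keyed by dotted names.
  stateDict(): Record<string, number[]> {
    const out: Record<string, number[]> = {};
    for (const [name, p] of this.namedParameters()) out[name] = [...p.data];
    return out;
  }
}

class ToyLinear extends ToyModule {
  constructor(inFeatures: number, outFeatures: number) {
    super();
    this.registerParameter("weight", makeParam([outFeatures, inFeatures]));
    this.registerParameter("bias", makeParam([outFeatures]));
  }
}

// An activation with no learnable parameters.
class ToyReLU extends ToyModule {}

class ToyNet extends ToyModule {
  constructor() {
    super();
    this.registerModule("fc1", new ToyLinear(4, 8));
    this.registerModule("relu", new ToyReLU());
    this.registerModule("fc2", new ToyLinear(8, 2));
  }
}

const toyNet = new ToyNet();
const toyKeys = Object.keys(toyNet.stateDict());
console.log(toyKeys); // [ 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias' ]
```

In this sketch a sub-layer that is never registered would be invisible to namedParameters() and stateDict(), which is the reason a registration step is needed at all. The parameter-free ToyReLU contributes no keys.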

Source Code

32-module-system/index.ts
import { GradTensor, parameter, type Tensor, tensor } from "deepbox/ndarray";
import { Linear, Module, ReLU, Sequential } from "deepbox/nn";

console.log("=== Neural Network Module System ===\n");

// ---------------------------------------------------------------------------
// Part 1: Custom Module with parameter registration
// ---------------------------------------------------------------------------
console.log("--- Part 1: Custom Module ---");

class MyNet extends Module {
  fc1: Linear;
  relu: ReLU;
  fc2: Linear;

  constructor(inputDim: number, hiddenDim: number, outputDim: number) {
    super();
    this.fc1 = new Linear(inputDim, hiddenDim);
    this.relu = new ReLU();
    this.fc2 = new Linear(hiddenDim, outputDim);
    this.registerModule("fc1", this.fc1);
    this.registerModule("relu", this.relu);
    this.registerModule("fc2", this.fc2);
  }

  override forward(x: GradTensor): GradTensor;
  override forward(x: Tensor): Tensor;
  override forward(x: Tensor | GradTensor): Tensor | GradTensor {
    if (x instanceof GradTensor) {
      let out: GradTensor = this.fc1.forward(x);
      out = this.relu.forward(out);
      return this.fc2.forward(out);
    }
    let out: Tensor = this.fc1.forward(x);
    out = this.relu.forward(out);
    return this.fc2.forward(out);
  }
}

const net = new MyNet(4, 8, 2);
console.log("MyNet(4 -> 8 -> 2)");

// ---------------------------------------------------------------------------
// Part 2: Parameter enumeration
// ---------------------------------------------------------------------------
console.log("\n--- Part 2: Parameters ---");

const params = Array.from(net.parameters());
console.log(`Total parameter tensors: ${params.length}`);
for (const p of params) {
  const t = p instanceof GradTensor ? p.tensor : p;
  console.log(`  Shape: [${t.shape.join(", ")}]`);
}

// ---------------------------------------------------------------------------
// Part 3: State dict — serialization & loading
// ---------------------------------------------------------------------------
console.log("\n--- Part 3: State Dict ---");

const stateDict = net.stateDict();
console.log("State dict parameter keys:");
for (const key of Object.keys(stateDict.parameters)) {
  console.log(`  ${key}`);
}

// Load state dict back (e.g., from a saved checkpoint)
net.loadStateDict(stateDict);
console.log("State dict loaded successfully");

// ---------------------------------------------------------------------------
// Part 4: Train/Eval mode
// ---------------------------------------------------------------------------
console.log("\n--- Part 4: Train/Eval Mode ---");

net.train();
console.log(`Training mode: ${net.training}`);

net.eval();
console.log(`Eval mode: ${net.training}`);
console.log("  Eval mode disables dropout and uses running stats for batchnorm");

// ---------------------------------------------------------------------------
// Part 5: Freeze/Unfreeze parameters
// ---------------------------------------------------------------------------
console.log("\n--- Part 5: Freeze/Unfreeze ---");

net.freezeParameters();
console.log("After freezeParameters:");
const frozenParams = Array.from(net.parameters());
const frozenGrads = frozenParams.filter((p) => p instanceof GradTensor && p.requiresGrad);
console.log(`  Parameters requiring grad: ${frozenGrads.length}`);

net.unfreezeParameters();
console.log("After unfreezeParameters:");
const unfrozenParams = Array.from(net.parameters());
const unfrozenGrads = unfrozenParams.filter((p) => p instanceof GradTensor && p.requiresGrad);
console.log(`  Parameters requiring grad: ${unfrozenGrads.length}`);

// ---------------------------------------------------------------------------
// Part 6: Sequential container
// ---------------------------------------------------------------------------
console.log("\n--- Part 6: Sequential Container ---");

const seqModel = new Sequential(new Linear(4, 8), new ReLU(), new Linear(8, 2));

console.log("Sequential(Linear(4,8), ReLU, Linear(8,2))");
const seqParams = Array.from(seqModel.parameters()).length;
console.log(`Parameters: ${seqParams}`);

// Forward pass with plain Tensor (inference)
const input = tensor([[1, 2, 3, 4]]);
const output = seqModel.forward(input);
const outTensor = output instanceof GradTensor ? output.tensor : output;
console.log(`Input shape:  [${input.shape.join(", ")}]`);
console.log(`Output shape: [${outTensor.shape.join(", ")}]`);

// Forward pass with GradTensor (training)
const gradInput = parameter([[1, 2, 3, 4]]);
const gradOutput = seqModel.forward(gradInput);
console.log(
  `GradTensor output requiresGrad: ${gradOutput instanceof GradTensor ? gradOutput.requiresGrad : false}`
);

console.log("\n=== Module System Complete ===");
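
MyNet contains no mode-dependent layers, so the train/eval toggle in Part 4 changes only the training flag. The following sketch shows what such a flag can gate in a dropout layer. ToyDropout is a hypothetical class written for this illustration, and it assumes the common "inverted dropout" convention (zero each element with probability p, scale survivors by 1 / (1 - p)); whether Deepbox's Dropout uses the same scaling is not confirmed here. A fixed sequence of draws replaces Math.random so the result is reproducible.

```typescript
// Hypothetical toy layer, NOT Deepbox's Dropout: it only shows how a
// training flag can switch a layer between two behaviors.
class ToyDropout {
  training = true;
  private readonly p: number;
  private readonly rand: () => number;

  constructor(p: number, rand: () => number = Math.random) {
    this.p = p;
    this.rand = rand;
  }

  train(): this {
    this.training = true;
    return this;
  }

  eval(): this {
    this.training = false;
    return this;
  }

  forward(x: number[]): number[] {
    // Eval mode: the layer is the identity.
    if (!this.training) return [...x];
    // Training mode (assumed inverted dropout): drop with probability p,
    // scale the surviving elements by 1 / (1 - p).
    const scale = 1 / (1 - this.p);
    return x.map((v) => (this.rand() < this.p ? 0 : v * scale));
  }
}

// Deterministic draws for the demonstration: drop, keep, drop, keep.
const fixedDraws = [0.1, 0.9, 0.1, 0.9];
let drawIndex = 0;
const dropout = new ToyDropout(0.5, () => fixedDraws[drawIndex++ % fixedDraws.length] ?? 0);

dropout.train();
const trainOut = dropout.forward([1, 2, 3, 4]);
dropout.eval();
const evalOut = dropout.forward([1, 2, 3, 4]);

console.log(trainOut); // [ 0, 4, 0, 8 ]
console.log(evalOut);  // [ 1, 2, 3, 4 ]
```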

Console Output

$ npx tsx 32-module-system/index.ts
=== Neural Network Module System ===

--- Part 1: Custom Module ---
MyNet(4 -> 8 -> 2)

--- Part 2: Parameters ---
Total parameter tensors: 4
  Shape: [8, 4]
  Shape: [8]
  Shape: [2, 8]
  Shape: [2]

--- Part 3: State Dict ---
State dict parameter keys:
  fc1.weight
  fc1.bias
  fc2.weight
  fc2.bias
State dict loaded successfully

--- Part 4: Train/Eval Mode ---
Training mode: true
Eval mode: false
  Eval mode disables dropout and uses running stats for batchnorm

--- Part 5: Freeze/Unfreeze ---
After freezeParameters:
  Parameters requiring grad: 0
After unfreezeParameters:
  Parameters requiring grad: 4

--- Part 6: Sequential Container ---
Sequential(Linear(4,8), ReLU, Linear(8,2))
Parameters: 4
Input shape:  [1, 4]
Output shape: [1, 2]
GradTensor output requiresGrad: true

=== Module System Complete ===
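
The counts in the output refer to parameter tensors, not individual scalars. As a quick check, the sketch below rebuilds the four shapes printed in Part 2 and sums their element counts. linearShapes is a hypothetical helper, and it assumes only what the output shows: each Linear(in, out) holds an [out, in] weight and an [out] bias, and ReLU holds no parameters.

```typescript
// Hypothetical helper: shapes of the weight and bias of a Linear(in, out) layer,
// matching the [out, in] and [out] shapes printed in Part 2.
function linearShapes(inFeatures: number, outFeatures: number): number[][] {
  return [[outFeatures, inFeatures], [outFeatures]];
}

// MyNet(4 -> 8 -> 2): fc1 = Linear(4, 8), fc2 = Linear(8, 2); ReLU adds nothing.
const expectedShapes = [...linearShapes(4, 8), ...linearShapes(8, 2)];

const tensorCount = expectedShapes.length;
const scalarCount = expectedShapes
  .map((shape) => shape.reduce((a, b) => a * b, 1))
  .reduce((a, b) => a + b, 0);

console.log(expectedShapes); // [ [ 8, 4 ], [ 8 ], [ 2, 8 ], [ 2 ] ]
console.log(tensorCount);    // 4
console.log(scalarCount);    // 58  (32 + 8 + 16 + 2)
```

The same arithmetic explains why the Sequential model in Part 6 also reports 4 parameters: it stacks the same two Linear layers around a parameter-free ReLU.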