Example 32: Neural Network Module System
Tags: Module, Serialization, Neural Networks
The Module base class is the foundation of Deepbox's neural network system. This example builds a custom Module subclass, enumerates its parameters, saves and restores its state with stateDict()/loadStateDict(), switches between train and eval modes, freezes and unfreezes parameters, and finally builds the same 4 -> 8 -> 2 network with the Sequential container.
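The dotted parameter keys reported by stateDict() (fc1.weight, fc1.bias, and so on) appear to come from the names passed to registerModule(). The sketch below is a hypothetical toy, not Deepbox's implementation: ToyModule, ToyLinear, ToyReLU, ToyNet, ToyParam, registerParam(), and namedParameters() are all invented for illustration. It assumes a parent keeps its registered children in insertion order and prefixes each child's parameter names with the child's name while walking the tree recursively.

```typescript
// Hypothetical toy module tree; NOT Deepbox's implementation.
// Illustrates how registering children lets a parent walk the tree
// recursively and build dotted names such as "fc1.weight".

class ToyParam {
  requiresGrad = true;
  constructor(public shape: number[]) {}
}

class ToyModule {
  private children = new Map<string, ToyModule>();
  private params = new Map<string, ToyParam>();

  protected registerParam(name: string, p: ToyParam): void {
    this.params.set(name, p);
  }

  registerModule(name: string, m: ToyModule): void {
    this.children.set(name, m);
  }

  // Yields [dottedName, param]: own params first, then each child's, in registration order.
  *namedParameters(prefix = ""): Generator<[string, ToyParam]> {
    for (const [name, p] of this.params) yield [prefix + name, p];
    for (const [name, child] of this.children) {
      yield* child.namedParameters(`${prefix}${name}.`);
    }
  }
}

class ToyLinear extends ToyModule {
  constructor(inDim: number, outDim: number) {
    super();
    // Weight stored as [out, in], matching the shapes in the console output.
    this.registerParam("weight", new ToyParam([outDim, inDim]));
    this.registerParam("bias", new ToyParam([outDim]));
  }
}

// An activation has no parameters, so it contributes no keys.
class ToyReLU extends ToyModule {}

class ToyNet extends ToyModule {
  constructor() {
    super();
    this.registerModule("fc1", new ToyLinear(4, 8));
    this.registerModule("relu", new ToyReLU());
    this.registerModule("fc2", new ToyLinear(8, 2));
  }
}

const toyKeys = Array.from(new ToyNet().namedParameters(), ([name]) => name);
console.log(toyKeys.join(", ")); // fc1.weight, fc1.bias, fc2.weight, fc2.bias
```

Because the toy ReLU registers no parameters, the key list has four entries, just like the Deepbox output for MyNet.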
Deepbox Modules Used
- deepbox/ndarray
- deepbox/nn

What You Will Learn
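- Subclass Module and call registerModule() for each sub-layer so the parent tracks it
- .parameters() yields every learnable tensor, recursing into registered sub-modules
- stateDict()/loadStateDict() save and restore parameters by name (e.g. fc1.weight) for serialization and checkpointing
- .train()/.eval() set the training flag, which changes how layers such as Dropout and BatchNorm behave
- .freezeParameters() turns off gradient tracking on every parameter (useful for transfer learning); .unfreezeParameters() turns it back on
- Sequential chains layers in order without a hand-written forward()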
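To make the mode and freezing behavior concrete, here is a hypothetical sketch of how a module tree might propagate the training flag and a requiresGrad toggle. ToyNode, ToyDropout, ToyParam, and setRequiresGrad() are invented names, not Deepbox APIs, and the dropout layer is a simplified inverted dropout over plain number arrays. The usage at the end freezes a whole tree and then unfreezes only a "head" child, the usual transfer-learning pattern; this example only shows Deepbox freezing and unfreezing the entire network, so per-submodule freezing in Deepbox is not covered here.

```typescript
// Hypothetical sketch, NOT Deepbox's implementation.
interface ToyParam {
  requiresGrad: boolean;
}

class ToyNode {
  training = true;
  readonly children: ToyNode[] = [];
  readonly params: ToyParam[] = [];

  // Set the mode on this node and, recursively, on every child.
  train(mode = true): this {
    this.training = mode;
    for (const child of this.children) child.train(mode);
    return this;
  }

  eval(): this {
    return this.train(false);
  }

  // Own parameters first, then each child's, depth-first.
  *parameters(): Generator<ToyParam> {
    yield* this.params;
    for (const child of this.children) yield* child.parameters();
  }

  // Freeze (false) or unfreeze (true) every parameter in this subtree.
  setRequiresGrad(flag: boolean): void {
    for (const p of this.parameters()) p.requiresGrad = flag;
  }
}

// Toy inverted dropout: while training, zeroes each value with probability p
// and scales survivors by 1 / (1 - p); in eval mode it is the identity.
class ToyDropout extends ToyNode {
  constructor(private readonly p: number, private readonly rng: () => number = Math.random) {
    super();
  }

  forward(x: number[]): number[] {
    if (!this.training) return x.slice();
    const scale = 1 / (1 - this.p);
    return x.map((v) => (this.rng() < this.p ? 0 : v * scale));
  }
}

// Usage: a "backbone" root with two parameters, a dropout child, and a "head" child with one parameter.
const head = new ToyNode();
head.params.push({ requiresGrad: true });
const drop = new ToyDropout(0.5, () => 0.9); // fixed draw 0.9 >= p, so nothing is dropped
const root = new ToyNode();
root.params.push({ requiresGrad: true }, { requiresGrad: true });
root.children.push(drop, head);

const trainOut = drop.forward([1, 2, 3]); // training: survivors scaled by 2 -> 2, 4, 6
root.eval(); // propagates training = false to drop and head
const evalOut = drop.forward([1, 2, 3]); // eval: unchanged -> 1, 2, 3

// Transfer-learning pattern: freeze everything, then unfreeze only the head.
root.setRequiresGrad(false);
head.setRequiresGrad(true);
const trainable = Array.from(root.parameters()).filter((p) => p.requiresGrad).length;
console.log(trainOut.join(", "), "|", evalOut.join(", "), "|", trainable);
```

Injecting the random source makes the dropout output deterministic for demonstration; a real layer would draw a fresh mask on every forward pass.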
Source Code
32-module-system/index.ts
```ts
import { GradTensor, parameter, type Tensor, tensor } from "deepbox/ndarray";
import { Linear, Module, ReLU, Sequential } from "deepbox/nn";

console.log("=== Neural Network Module System ===\n");

// ---------------------------------------------------------------------------
// Part 1: Custom Module with parameter registration
// ---------------------------------------------------------------------------
console.log("--- Part 1: Custom Module ---");

class MyNet extends Module {
  fc1: Linear;
  relu: ReLU;
  fc2: Linear;

  constructor(inputDim: number, hiddenDim: number, outputDim: number) {
    super();
    this.fc1 = new Linear(inputDim, hiddenDim);
    this.relu = new ReLU();
    this.fc2 = new Linear(hiddenDim, outputDim);
    this.registerModule("fc1", this.fc1);
    this.registerModule("relu", this.relu);
    this.registerModule("fc2", this.fc2);
  }

  override forward(x: GradTensor): GradTensor;
  override forward(x: Tensor): Tensor;
  override forward(x: Tensor | GradTensor): Tensor | GradTensor {
    if (x instanceof GradTensor) {
      let out: GradTensor = this.fc1.forward(x);
      out = this.relu.forward(out);
      return this.fc2.forward(out);
    }
    let out: Tensor = this.fc1.forward(x);
    out = this.relu.forward(out);
    return this.fc2.forward(out);
  }
}

const net = new MyNet(4, 8, 2);
console.log("MyNet(4 -> 8 -> 2)");

// ---------------------------------------------------------------------------
// Part 2: Parameter enumeration
// ---------------------------------------------------------------------------
console.log("\n--- Part 2: Parameters ---");

const params = Array.from(net.parameters());
console.log(`Total parameter tensors: ${params.length}`);
for (const p of params) {
  const t = p instanceof GradTensor ? p.tensor : p;
  console.log(`  Shape: [${t.shape.join(", ")}]`);
}

// ---------------------------------------------------------------------------
// Part 3: State dict — serialization & loading
// ---------------------------------------------------------------------------
console.log("\n--- Part 3: State Dict ---");

const stateDict = net.stateDict();
console.log("State dict parameter keys:");
for (const key of Object.keys(stateDict.parameters)) {
  console.log(`  ${key}`);
}

// Load state dict back (e.g., from a saved checkpoint)
net.loadStateDict(stateDict);
console.log("State dict loaded successfully");

// ---------------------------------------------------------------------------
// Part 4: Train/Eval mode
// ---------------------------------------------------------------------------
console.log("\n--- Part 4: Train/Eval Mode ---");

net.train();
console.log(`Training mode: ${net.training}`);

net.eval();
console.log(`Eval mode: ${net.training}`);
console.log("  Eval mode disables dropout and uses running stats for batchnorm");

// ---------------------------------------------------------------------------
// Part 5: Freeze/Unfreeze parameters
// ---------------------------------------------------------------------------
console.log("\n--- Part 5: Freeze/Unfreeze ---");

net.freezeParameters();
console.log("After freezeParameters:");
const frozenParams = Array.from(net.parameters());
const frozenGrads = frozenParams.filter((p) => p instanceof GradTensor && p.requiresGrad);
console.log(`  Parameters requiring grad: ${frozenGrads.length}`);

net.unfreezeParameters();
console.log("After unfreezeParameters:");
const unfrozenParams = Array.from(net.parameters());
const unfrozenGrads = unfrozenParams.filter((p) => p instanceof GradTensor && p.requiresGrad);
console.log(`  Parameters requiring grad: ${unfrozenGrads.length}`);

// ---------------------------------------------------------------------------
// Part 6: Sequential container
// ---------------------------------------------------------------------------
console.log("\n--- Part 6: Sequential Container ---");

const seqModel = new Sequential(new Linear(4, 8), new ReLU(), new Linear(8, 2));

console.log("Sequential(Linear(4,8), ReLU, Linear(8,2))");
const seqParams = Array.from(seqModel.parameters()).length;
console.log(`Parameters: ${seqParams}`);

// Forward pass with plain Tensor (inference)
const input = tensor([[1, 2, 3, 4]]);
const output = seqModel.forward(input);
const outTensor = output instanceof GradTensor ? output.tensor : output;
console.log(`Input shape: [${input.shape.join(", ")}]`);
console.log(`Output shape: [${outTensor.shape.join(", ")}]`);

// Forward pass with GradTensor (training)
const gradInput = parameter([[1, 2, 3, 4]]);
const gradOutput = seqModel.forward(gradInput);
console.log(
  `GradTensor output requiresGrad: ${gradOutput instanceof GradTensor ? gradOutput.requiresGrad : false}`
);

console.log("\n=== Module System Complete ===");
```

Console Output
```text
$ npx tsx 32-module-system/index.ts
=== Neural Network Module System ===

--- Part 1: Custom Module ---
MyNet(4 -> 8 -> 2)

--- Part 2: Parameters ---
Total parameter tensors: 4
  Shape: [8, 4]
  Shape: [8]
  Shape: [2, 8]
  Shape: [2]

--- Part 3: State Dict ---
State dict parameter keys:
  fc1.weight
  fc1.bias
  fc2.weight
  fc2.bias
State dict loaded successfully

--- Part 4: Train/Eval Mode ---
Training mode: true
Eval mode: false
  Eval mode disables dropout and uses running stats for batchnorm

--- Part 5: Freeze/Unfreeze ---
After freezeParameters:
  Parameters requiring grad: 0
After unfreezeParameters:
  Parameters requiring grad: 4

--- Part 6: Sequential Container ---
Sequential(Linear(4,8), ReLU, Linear(8,2))
Parameters: 4
Input shape: [1, 4]
Output shape: [1, 2]
GradTensor output requiresGrad: true

=== Module System Complete ===
```
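"Total parameter tensors: 4" counts tensors, not individual weights. Using the shapes printed above, the number of learnable scalars is (8 × 4 + 8) + (2 × 8 + 2) = 40 + 18 = 58. The Sequential model in Part 6 uses the same Linear(4, 8) and Linear(8, 2) layers, so it holds the same 58. The plain TypeScript check below makes no Deepbox calls; numel() is a helper defined here for illustration.

```typescript
// Shapes copied from the console output above (Linear weights are [out, in]).
const shapes: number[][] = [
  [8, 4], // fc1.weight
  [8], // fc1.bias
  [2, 8], // fc2.weight
  [2], // fc2.bias
];

// Number of scalars in one tensor = product of its dimensions.
const numel = (shape: number[]): number => shape.reduce((acc, d) => acc * d, 1);

const perTensor = shapes.map(numel); // 32, 8, 16, 2
const total = perTensor.reduce((a, b) => a + b, 0); // 58
console.log(`${shapes.length} tensors, ${total} scalars`); // 4 tensors, 58 scalars
```

In general, a fully connected layer with bias mapping in features to out features holds in * out + out scalars.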