TensorFlow.js 高级应用 #
高级主题概述 #
本章涵盖 TensorFlow.js 的高级应用技巧,帮助你构建生产级机器学习应用。
text
┌─────────────────────────────────────────────────────────────┐
│ 高级应用主题 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 自定义组件 │ │ 性能优化 │ │ 调试技巧 │ │
│ │ 层/损失函数 │ │ 内存/速度 │ │ 可视化 │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 生产部署 │ │ 安全考虑 │ │ 最佳实践 │ │
│ │ 容器化 │ │ 数据保护 │ │ 代码规范 │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
自定义组件 #
自定义层 #
javascript
/**
 * Hand-written fully connected layer: output = activation(input · kernel + bias).
 * Implements getConfig()/fromConfig() so that a model using it can be serialized
 * and deserialized through tf.serialization.registerClass.
 */
class CustomDense extends tf.layers.Layer {
  /**
   * @param {number} units - size of the last output dimension.
   * @param {Object} [config] - base Layer args plus an optional `activation` name (default 'linear').
   */
  constructor(units, config = {}) {
    super(config);
    this.units = units;
    this.activation = config.activation || 'linear';
    // Created once here rather than inside call(), so a forward pass does not allocate a new layer object.
    this.activationLayer = this.activation === 'linear'
      ? null
      : tf.layers.activation({ activation: this.activation });
  }

  /** Creates the kernel [inputDim, units] (Glorot-normal) and bias [units] (zeros) weights. */
  build(inputShape) {
    this.kernel = this.addWeight(
      'kernel',
      [inputShape[inputShape.length - 1], this.units],
      'float32',
      tf.initializers.glorotNormal()
    );
    this.bias = this.addWeight(
      'bias',
      [this.units],
      'float32',
      tf.initializers.zeros()
    );
  }

  /** Forward pass; intermediates (the matMul result) are released by tf.tidy. */
  call(inputs) {
    return tf.tidy(() => {
      const input = inputs instanceof tf.Tensor ? inputs : inputs[0];
      const output = input.matMul(this.kernel.read()).add(this.bias.read());
      return this.activationLayer ? this.activationLayer.apply(output) : output;
    });
  }

  /** Same leading dimensions as the input, last dimension replaced by `units`. */
  computeOutputShape(inputShape) {
    return [...inputShape.slice(0, -1), this.units];
  }

  /** Serializable config; `units` and `activation` are needed to rebuild the layer. */
  getConfig() {
    return { ...super.getConfig(), units: this.units, activation: this.activation };
  }

  /**
   * The inherited fromConfig would call `new cls(config)`, passing the whole config as `units`;
   * unpack it to match this constructor's (units, config) signature.
   */
  static fromConfig(cls, config) {
    return new cls(config.units, config);
  }

  static get className() {
    return 'CustomDense';
  }
}
tf.serialization.registerClass(CustomDense);
自定义损失函数 #
javascript
/**
 * Builds a focal-loss function (yTrue, yPred) => scalar for use as a `loss` in model.compile.
 * Per element: alphaWeight · (1 - p)^gamma · (-yTrue · log p), summed over the last axis and
 * averaged over the batch. Because the cross-entropy term is multiplied by yTrue, the
 * (1 - alpha) weight for yTrue = 0 entries contributes nothing in this formulation.
 *
 * @param {number} [gamma=2.0] - focusing exponent; larger values down-weight easy examples.
 * @param {number} [alpha=0.25] - weight applied to entries where yTrue is 1.
 * @returns {function(tf.Tensor, tf.Tensor): tf.Tensor} loss function.
 */
function focalLoss(gamma = 2.0, alpha = 0.25) {
  const epsilon = 1e-7;
  return (yTrue, yPred) => tf.tidy(() => {
    // clipByValue takes plain numbers for its bounds; clipping keeps log() finite.
    const yPredClipped = tf.clipByValue(yPred, epsilon, 1 - epsilon);
    const crossEntropy = yTrue.mul(yPredClipped.log()).neg();
    const weight = yTrue.mul(alpha)
      .add(tf.scalar(1).sub(yTrue).mul(1 - alpha));
    const focal = weight.mul(tf.scalar(1).sub(yPredClipped).pow(gamma));
    return focal.mul(crossEntropy).sum(-1).mean();
  });
}

// Usage example: `model` is assumed to be an already-built tf.LayersModel.
model.compile({
  optimizer: 'adam',
  loss: focalLoss(2.0, 0.25),
});
自定义激活函数 #
javascript
/**
 * Swish activation, element-wise: x · sigmoid(x).
 * Wrapped in tf.tidy so the intermediate sigmoid tensor is released even when called
 * eagerly outside any other tidy scope.
 */
function swish(x) {
  return tf.tidy(() => x.mul(tf.sigmoid(x)));
}

/** Weight-less layer that applies swish to its (first) input; output shape equals input shape. */
class Swish extends tf.layers.Layer {
  call(inputs) {
    const input = inputs instanceof tf.Tensor ? inputs : inputs[0];
    return swish(input);
  }

  static get className() {
    return 'Swish';
  }
}
tf.serialization.registerClass(Swish);
自定义正则化器 #
javascript
/**
 * Combined L1/L2 weight penalty: l1 · Σ|x| + l2 · Σx².
 * A term whose coefficient is not positive is skipped.
 * NOTE(review): confirm that `tf.regularizers.Regularizer` is exposed by the tfjs version in use.
 */
class L1L2Regularizer extends tf.regularizers.Regularizer {
  /**
   * @param {number} [l1=0.01] - L1 coefficient.
   * @param {number} [l2=0.01] - L2 coefficient.
   */
  constructor(l1 = 0.01, l2 = 0.01) {
    super();
    this.l1 = l1;
    this.l2 = l2;
  }

  /** @returns {tf.Scalar} the penalty for weight tensor `x`; intermediates are tidied. */
  apply(x) {
    return tf.tidy(() => {
      let regularization = tf.scalar(0);
      if (this.l1 > 0) {
        regularization = regularization.add(x.abs().sum().mul(this.l1));
      }
      if (this.l2 > 0) {
        regularization = regularization.add(x.square().sum().mul(this.l2));
      }
      return regularization;
    });
  }

  /** Serializable config consumed by fromConfig. */
  getConfig() {
    return { l1: this.l1, l2: this.l2 };
  }

  /** The inherited fromConfig would pass the config object as `l1`; unpack it instead. */
  static fromConfig(cls, config) {
    return new cls(config.l1, config.l2);
  }

  static get className() {
    return 'L1L2Regularizer';
  }
}
tf.serialization.registerClass(L1L2Regularizer);
自定义优化器 #
javascript
/**
 * SGD with classical momentum:
 *   v ← momentum · v + g
 *   w ← w − learningRate · v
 * (tfjs also ships a built-in momentum optimizer, tf.train.momentum; this class is a teaching example.)
 */
class SGDMomentum extends tf.Optimizer {
  /**
   * @param {number} learningRate - step size.
   * @param {number} [momentum=0.9] - decay factor applied to the previous velocity.
   */
  constructor(learningRate, momentum = 0.9) {
    super();
    this.learningRate = learningRate;
    this.momentum = momentum;
    // Velocity state per variable name, held in non-trainable tf.Variables.
    this.velocities = {};
  }

  /**
   * Applies one update step.
   * @param {Object<string, tf.Tensor>|Array<{name: string, tensor: tf.Tensor}>} variableGradients
   *   gradients keyed by variable name (map form) or as named-tensor entries (array form).
   * @throws {Error} if a gradient names a variable that is not registered.
   */
  applyGradients(variableGradients) {
    const entries = Array.isArray(variableGradients)
      ? variableGradients.map(({ name, tensor }) => [name, tensor])
      : Object.entries(variableGradients);
    for (const [varName, gradient] of entries) {
      if (gradient == null) {
        continue;
      }
      // tf.Optimizer has no getVariable(); resolve the variable through the engine's registry.
      const variable = tf.engine().registeredVariables[varName];
      if (variable == null) {
        throw new Error(`Unknown variable: ${varName}`);
      }
      if (!this.velocities[varName]) {
        this.velocities[varName] = tf.tidy(() => tf.zerosLike(gradient).variable(false));
      }
      const velocity = this.velocities[varName];
      // Update in place via assign() inside tidy so neither the old velocity nor the
      // temporaries leak on every step.
      tf.tidy(() => {
        const newVelocity = velocity.mul(this.momentum).add(gradient);
        velocity.assign(newVelocity);
        variable.assign(variable.sub(newVelocity.mul(this.learningRate)));
      });
    }
  }

  /** Releases the velocity variables in addition to the base optimizer state. */
  dispose() {
    super.dispose();
    tf.dispose(Object.values(this.velocities));
    this.velocities = {};
  }

  /** Serializable config consumed by fromConfig. */
  getConfig() {
    return { learningRate: this.learningRate, momentum: this.momentum };
  }

  /** The inherited fromConfig would pass the config object as `learningRate`; unpack it instead. */
  static fromConfig(cls, config) {
    return new cls(config.learningRate, config.momentum);
  }

  static get className() {
    return 'SGDMomentum';
  }
}
tf.serialization.registerClass(SGDMomentum);
性能优化 #
内存优化 #
javascript
/**
 * Runs one forward pass and returns the output values as a typed array.
 * The input tensor and the prediction are created inside tf.tidy and released before returning.
 */
function optimizedPredict(model, input) {
  const runForward = () => model.predict(tf.tensor(input)).dataSync();
  return tf.tidy(runForward);
}
/**
 * Predicts `inputs` in chunks of `batchSize`, returning one result per input in input order.
 *
 * @param {tf.LayersModel} model
 * @param {Array} inputs - per-sample numeric (nested) arrays; samples must share one shape for tf.stack.
 * @param {number} [batchSize=32] - positive integer chunk size.
 * @returns {Promise<Array>} per-sample predictions as nested JS arrays.
 * @throws {RangeError} if batchSize is not a positive integer (a zero step would never advance).
 */
async function batchPredictOptimized(model, inputs, batchSize = 32) {
  if (!Number.isInteger(batchSize) || batchSize < 1) {
    throw new RangeError('batchSize 必须是正整数');
  }
  const results = [];
  for (let i = 0; i < inputs.length; i += batchSize) {
    // Only the prediction tensor survives the tidy; per-sample and stacked inputs are freed.
    const predictions = tf.tidy(() => {
      const batch = inputs.slice(i, i + batchSize);
      return model.predict(tf.stack(batch.map((x) => tf.tensor(x))));
    });
    try {
      // array() reads back asynchronously instead of blocking the thread like arraySync().
      results.push(...(await predictions.array()));
    } finally {
      predictions.dispose();
    }
  }
  return results;
}
运算融合 #
javascript
/**
 * Element-wise sqrt((x² + 1) / 2).
 * tf.tidy frees the three intermediate tensors; only the final result is returned.
 * NOTE(review): despite the name, an eager op chain like this is presumably executed as separate
 * kernels rather than one fused kernel — confirm for the backend in use.
 */
function fusedOperations(input) {
  const pipeline = () => {
    const squared = input.square();
    const shifted = squared.add(1);
    const halved = shifted.div(2);
    return halved.sqrt();
  };
  return tf.tidy(pipeline);
}
预分配内存 #
javascript
/**
 * Predictor that reuses one host-side TensorBuffer for its input across calls.
 * `outputBuffer` is allocated from `outputShape` but predict() does not currently use it.
 */
class PreallocatedPredictor {
  /**
   * @param {tf.LayersModel} model
   * @param {number[]} inputShape - shape of the reusable input buffer (presumably including the batch dim).
   * @param {number[]} outputShape - shape of `outputBuffer`.
   */
  constructor(model, inputShape, outputShape) {
    this.model = model;
    this.inputBuffer = tf.buffer(inputShape);
    this.outputBuffer = tf.buffer(outputShape);
  }

  /**
   * @param {ArrayLike<number>} data - input values in row-major order, at most inputBuffer.size of them.
   * @returns {TypedArray} model output values.
   * @throws {RangeError} if `data` holds more values than the buffer.
   */
  predict(data) {
    const { values, size } = this.inputBuffer;
    if (data.length > size) {
      throw new RangeError(`输入长度 ${data.length} 超过缓冲区大小 ${size}`);
    }
    // Write by flat index so buffers of any rank work, and zero the remainder so
    // values from a previous, longer request cannot leak into this one.
    values.set(data);
    values.fill(0, data.length);
    // toTensor() allocates a fresh tensor each call; tidy disposes it and the prediction after readback.
    return tf.tidy(() => this.model.predict(this.inputBuffer.toTensor()).dataSync());
  }
}
异步处理 #
javascript
/**
 * Serializes predictions through a FIFO queue so that only one runs at a time.
 * Each addToQueue() promise settles with its own prediction or error.
 */
class AsyncPredictor {
  constructor(model) {
    this.model = model;
    this.queue = [];
    this.processing = false;
  }

  /**
   * @param {Array} input - numeric (nested) array passed to tf.tensor.
   * @returns {Promise<TypedArray>} resolves with the output values, rejects if prediction fails.
   */
  async addToQueue(input) {
    return new Promise((resolve, reject) => {
      this.queue.push({ input, resolve, reject });
      // processQueue never rejects (errors go to each item's reject), so it is safe not to await it.
      void this.processQueue();
    });
  }

  /** Drains the queue; a failing item rejects its own promise without stalling the rest. */
  async processQueue() {
    if (this.processing || this.queue.length === 0) return;
    this.processing = true;
    try {
      while (this.queue.length > 0) {
        const { input, resolve, reject } = this.queue.shift();
        try {
          resolve(await this.predict(input));
        } catch (error) {
          reject(error);
        }
      }
    } finally {
      this.processing = false;
    }
  }

  /** One forward pass; reads the result back asynchronously and disposes the output tensor. */
  async predict(input) {
    const output = tf.tidy(() => this.model.predict(tf.tensor(input)));
    try {
      return await output.data();
    } finally {
      output.dispose();
    }
  }
}
调试技巧 #
张量调试 #
javascript
/**
 * Logs shape, dtype and summary statistics (min, max, mean, std) of `tensor`.
 * Statistic tensors are computed inside tf.tidy, so nothing leaks.
 *
 * @param {tf.Tensor} tensor
 * @param {string} [name='tensor'] - label printed in the header.
 */
function debugTensor(tensor, name = 'tensor') {
  console.log(`=== ${name} ===`);
  console.log('Shape:', tensor.shape);
  console.log('Dtype:', tensor.dtype);
  // Tensor has no std(); derive it from the variance returned by tf.moments.
  const [min, max, mean, std] = tf.tidy(() => {
    const { mean: meanTensor, variance } = tf.moments(tensor);
    return [tensor.min(), tensor.max(), meanTensor, variance.sqrt()]
      .map((t) => t.dataSync()[0]);
  });
  console.log('Min:', min);
  console.log('Max:', max);
  console.log('Mean:', mean);
  console.log('Std:', std);
}
梯度检查 #
javascript
/**
 * Logs the mean-squared-error loss of `model` on (x, y), followed by the mean, std, max and min
 * of the gradient with respect to each variable returned by tf.variableGrads (its default
 * variable list). Every tensor produced here is disposed before returning.
 *
 * @param {tf.LayersModel} model
 * @param {tf.Tensor} x - model input.
 * @param {tf.Tensor} y - targets compatible with model.predict(x).
 */
function checkGradients(model, x, y) {
  const { grads, value } = tf.variableGrads(() => {
    const pred = model.predict(x);
    return tf.losses.meanSquaredError(y, pred);
  });
  try {
    console.log('Loss:', value.dataSync()[0]);
    for (const [name, grad] of Object.entries(grads)) {
      // Tensor has no std(); use tf.moments' variance. tidy frees the statistic tensors.
      const [mean, std, max, min] = tf.tidy(() => {
        const { mean: meanTensor, variance } = tf.moments(grad);
        return [meanTensor, variance.sqrt(), grad.max(), grad.min()]
          .map((t) => t.dataSync()[0]);
      });
      console.log(`${name}:`);
      console.log(' Mean:', mean);
      console.log(' Std:', std);
      console.log(' Max:', max);
      console.log(' Min:', min);
    }
  } finally {
    value.dispose();
    tf.dispose(grads);
  }
}
层输出检查 #
javascript
/**
 * Runs `input` through a sub-model ending at each layer of `model` and summarizes that layer's output.
 * The sub-models share the original weights, so they are intentionally not disposed.
 *
 * @param {tf.LayersModel} model
 * @param {tf.Tensor} input - a valid input for `model`.
 * @returns {Array<{name: string, shape: number[], mean: number, std: number}>} one entry per layer.
 */
function inspectLayerOutputs(model, input) {
  return model.layers.map((layer) => {
    const layerModel = tf.model({
      inputs: model.input,
      outputs: layer.output,
    });
    // tidy disposes the layer output and the statistic tensors; only plain numbers escape.
    return tf.tidy(() => {
      const output = layerModel.predict(input);
      // Tensor has no std(); derive it from tf.moments' variance.
      const { mean, variance } = tf.moments(output);
      return {
        name: layer.name,
        shape: output.shape,
        mean: mean.dataSync()[0],
        std: variance.sqrt().dataSync()[0],
      };
    });
  });
}
训练监控 #
javascript
/**
 * Epoch-end callback that records metric history, prints a per-epoch summary, and warns
 * about NaN loss, exploding loss (> 1e6), or loss that increased over the last 5 epochs.
 * Reads logs.loss / logs.acc / logs.val_loss / logs.val_acc; a missing metric prints as 'N/A'.
 */
class TrainingMonitor {
  constructor() {
    this.history = {
      loss: [],
      accuracy: [],
      valLoss: [],
      valAccuracy: [],
    };
    // Bind so the method keeps access to `this.history` when passed around unbound,
    // e.g. `callbacks: { onEpochEnd: monitor.onEpochEnd }`.
    this.onEpochEnd = this.onEpochEnd.bind(this);
  }

  /**
   * @param {number} epoch - zero-based epoch index (printed 1-based).
   * @param {Object} logs - metric values for the finished epoch.
   */
  onEpochEnd(epoch, logs) {
    this.history.loss.push(logs.loss);
    this.history.accuracy.push(logs.acc);
    this.history.valLoss.push(logs.val_loss);
    this.history.valAccuracy.push(logs.val_acc);
    // Some keys may be absent (e.g. no validation data); calling toFixed on undefined would throw.
    const fmt = (v) => (typeof v === 'number' ? v.toFixed(4) : 'N/A');
    console.log(`Epoch ${epoch + 1}:`);
    console.log(` Loss: ${fmt(logs.loss)}`);
    console.log(` Accuracy: ${fmt(logs.acc)}`);
    console.log(` Val Loss: ${fmt(logs.val_loss)}`);
    console.log(` Val Accuracy: ${fmt(logs.val_acc)}`);
    this.checkForIssues(logs);
  }

  /** Emits console warnings for suspicious loss values; `logs.loss` must already be in history. */
  checkForIssues(logs) {
    if (Number.isNaN(logs.loss)) {
      console.warn('警告: 损失为 NaN!');
    }
    if (logs.loss > 1e6) {
      console.warn('警告: 损失爆炸!');
    }
    const window = 5;
    // A 5-epoch window can be judged as soon as 5 entries exist.
    if (this.history.loss.length >= window) {
      const recent = this.history.loss.slice(-window);
      const increasing = recent.every((v, i) => i === 0 || v > recent[i - 1]);
      if (increasing) {
        console.warn('警告: 损失持续增加!');
      }
    }
  }
}
生产最佳实践 #
错误处理 #
javascript
/**
 * Wraps a model with a per-attempt timeout and linear back-off retries.
 */
class SafePredictor {
  constructor(model) {
    this.model = model;
  }

  /**
   * @param {Array} input - numeric (nested) array passed to tf.tensor.
   * @param {{timeout?: number, retries?: number}} [options] - `timeout` in ms per attempt (default 5000);
   *   `retries` is the total number of attempts (default 3; values below 1 are treated as 1).
   * @returns {Promise<TypedArray>} the model output values.
   * @throws {Error} '预测失败: …' after the last attempt fails; `cause` holds that attempt's error.
   */
  async predict(input, options = {}) {
    const { timeout = 5000, retries = 3 } = options;
    const attempts = Math.max(1, retries);
    for (let attempt = 0; attempt < attempts; attempt++) {
      const timer = this._timeout(timeout);
      let failure;
      try {
        return await Promise.race([this._predict(input), timer]);
      } catch (error) {
        failure = error;
      } finally {
        // Clear the timer so it neither fires late nor keeps the process alive.
        timer.cancel();
      }
      console.error(`预测失败 (尝试 ${attempt + 1}/${attempts}):`, failure);
      if (attempt === attempts - 1) {
        throw new Error(`预测失败: ${failure.message}`, { cause: failure });
      }
      await this._delay(100 * (attempt + 1));
    }
  }

  /**
   * One forward pass. The readback uses async data() so the event loop stays free
   * and the timeout in predict() can actually win the race.
   */
  async _predict(input) {
    const output = tf.tidy(() => this.model.predict(tf.tensor(input)));
    try {
      return await output.data();
    } finally {
      output.dispose();
    }
  }

  /** A promise that rejects with '预测超时' after `ms`; call `.cancel()` on it to clear the timer. */
  _timeout(ms) {
    let timer;
    const promise = new Promise((_, reject) => {
      timer = setTimeout(() => reject(new Error('预测超时')), ms);
    });
    promise.cancel = () => clearTimeout(timer);
    return promise;
  }

  _delay(ms) {
    return new Promise(resolve => setTimeout(resolve, ms));
  }
}
资源管理 #
javascript
/**
 * Checks tf.memory() against byte and tensor-count thresholds around wrapped operations.
 * It only reports. It deliberately does NOT call tf.disposeVariables(), which would destroy
 * every variable, including the weights of loaded models.
 */
class ResourceManager {
  /**
   * @param {Object} [options]
   * @param {number} [options.maxMemory=512 MiB] - threshold for tf.memory().numBytes.
   * @param {number} [options.maxTensors=1000] - threshold for tf.memory().numTensors.
   * @param {function(string, Object): void} [options.onLimitExceeded] - called with
   *   ('memory' | 'tensors', memoryInfo) when a threshold is exceeded.
   */
  constructor(options = {}) {
    this.maxMemory = options.maxMemory || 512 * 1024 * 1024;
    this.maxTensors = options.maxTensors || 1000;
    this.onLimitExceeded = options.onLimitExceeded || null;
  }

  /** @returns {Object} the tf.memory() snapshot that was checked. */
  checkResources() {
    const mem = tf.memory();
    if (mem.numBytes > this.maxMemory) {
      console.warn('内存使用过高,请检查张量泄漏');
      if (this.onLimitExceeded) this.onLimitExceeded('memory', mem);
    }
    if (mem.numTensors > this.maxTensors) {
      console.warn('张量数量过多,请检查张量泄漏');
      if (this.onLimitExceeded) this.onLimitExceeded('tensors', mem);
    }
    return mem;
  }

  /**
   * Returns a wrapper around `fn` that checks resources once before the call and once after it
   * completes. For a Promise result, the second check runs when the promise settles.
   */
  wrapOperation(fn) {
    return (...args) => {
      this.checkResources();
      let result;
      try {
        result = fn(...args);
      } catch (error) {
        this.checkResources();
        throw error;
      }
      if (result instanceof Promise) {
        return result.finally(() => this.checkResources());
      }
      this.checkResources();
      return result;
    };
  }
}
模型版本管理 #
javascript
/**
 * Loads and caches model versions from `${baseUrl}/v${version}/model.json` and tracks a current version.
 * Concurrent requests for the same version share one in-flight load.
 */
class ModelVersionManager {
  constructor(baseUrl) {
    this.baseUrl = baseUrl;
    this.models = new Map();
    // version -> in-flight load promise, so duplicate requests don't load (and leak) twice
    this.pending = new Map();
    this.currentVersion = null;
  }

  /** Loads `version` (or reuses the cached model) and makes it the current version. */
  async loadVersion(version) {
    if (this.models.has(version)) {
      this.currentVersion = version;
      return this.models.get(version);
    }
    const model = await this._fetchVersion(version);
    this.currentVersion = version;
    return model;
  }

  /** Returns the cached model, the in-flight load, or starts a new load; never changes currentVersion. */
  _fetchVersion(version) {
    if (this.models.has(version)) {
      return Promise.resolve(this.models.get(version));
    }
    if (!this.pending.has(version)) {
      const modelUrl = `${this.baseUrl}/v${version}/model.json`;
      const loading = tf.loadLayersModel(modelUrl)
        .then((model) => {
          this.models.set(version, model);
          return model;
        })
        .finally(() => this.pending.delete(version));
      this.pending.set(version, loading);
    }
    return this.pending.get(version);
  }

  /**
   * @throws {Error} if no version is current. Compares against null so version 0 is valid.
   */
  getCurrentModel() {
    if (this.currentVersion === null) {
      throw new Error('没有加载任何模型版本');
    }
    return this.models.get(this.currentVersion);
  }

  /**
   * Loads all `versions` in parallel. Afterwards the last listed version is current,
   * instead of whichever load happened to finish last.
   */
  async preloadVersions(versions) {
    await Promise.all(versions.map(v => this._fetchVersion(v)));
    if (versions.length > 0) {
      this.currentVersion = versions[versions.length - 1];
    }
  }

  /** Disposes one cached version; clears currentVersion if it pointed at it. */
  disposeVersion(version) {
    const model = this.models.get(version);
    if (model) {
      model.dispose();
      this.models.delete(version);
      if (this.currentVersion === version) {
        this.currentVersion = null;
      }
    }
  }

  /**
   * Disposes every cached model. NOTE(review): loads still in flight will repopulate the cache
   * when they finish.
   */
  disposeAll() {
    for (const model of this.models.values()) {
      model.dispose();
    }
    this.models.clear();
    this.currentVersion = null;
  }
}
日志记录 #
javascript
/**
 * Minimal structured logger: writes one JSON line per entry to console.log
 * ({timestamp, level, message, ...data}) when the entry's level is at or above the threshold.
 */
class TFLogger {
  /**
   * @param {string} [level='info'] - threshold: 'debug' | 'info' | 'warn' | 'error'.
   *   Unknown names fall back to 'info'; otherwise every comparison would be false
   *   and all output would be silently suppressed.
   */
  constructor(level = 'info') {
    this.levels = { debug: 0, info: 1, warn: 2, error: 3 };
    this.level = Object.hasOwn(this.levels, level) ? level : 'info';
  }

  /** Emits an entry; an unrecognised `level` argument is dropped. */
  log(level, message, data = {}) {
    if (this.levels[level] >= this.levels[this.level]) {
      const timestamp = new Date().toISOString();
      console.log(JSON.stringify({
        timestamp,
        level,
        message,
        ...data,
      }));
    }
  }

  debug(message, data) { this.log('debug', message, data); }
  info(message, data) { this.log('info', message, data); }
  warn(message, data) { this.log('warn', message, data); }
  error(message, data) { this.log('error', message, data); }

  /** Logs tf.memory() tensor, byte, and data-buffer counts at info level. */
  logMemory() {
    const mem = tf.memory();
    this.info('Memory status', {
      tensors: mem.numTensors,
      bytes: mem.numBytes,
      buffers: mem.numDataBuffers,
    });
  }
}
安全考虑 #
输入验证 #
javascript
/**
 * Validates that `input` is a (nested) array of finite numbers whose element count fits `expectedShape`.
 * Only the total element count is checked, not the nesting structure.
 *
 * @param {Array} input - candidate model input (without the batch dimension).
 * @param {Array<number|null>} expectedShape - expected dims; null/undefined/-1 mark an unknown size
 *   (as in a model's inputShape). An empty shape means exactly one element.
 * @returns {true}
 * @throws {Error} '输入必须是数组', '输入形状不匹配: …', '输入包含非数值' or '输入包含无穷值'.
 */
function validateInput(input, expectedShape) {
  if (!Array.isArray(input)) {
    throw new Error('输入必须是数组');
  }
  const flatInput = input.flat(Infinity);
  const knownDims = expectedShape.filter((d) => d != null && d >= 0);
  // Initial value 1 makes an empty (scalar) shape valid instead of throwing in reduce().
  const knownSize = knownDims.reduce((a, b) => a * b, 1);
  let sizeMatches;
  if (knownDims.length === expectedShape.length) {
    sizeMatches = flatInput.length === knownSize;
  } else if (knownSize === 0) {
    sizeMatches = flatInput.length === 0;
  } else {
    // With unknown dims, only require the count to be a multiple of the known dims' product.
    sizeMatches = flatInput.length % knownSize === 0;
  }
  if (!sizeMatches) {
    throw new Error(`输入形状不匹配: 期望 ${expectedShape}, 得到 ${flatInput.length}`);
  }
  for (const value of flatInput) {
    if (typeof value !== 'number' || Number.isNaN(value)) {
      throw new Error('输入包含非数值');
    }
    if (!Number.isFinite(value)) {
      throw new Error('输入包含无穷值');
    }
  }
  return true;
}
/**
 * Flattens arbitrarily nested arrays into a single array, depth-first and left to right.
 * Array.prototype.flat avoids the O(n²) copying that repeated concat() inside reduce() incurs.
 *
 * @param {Array} arr
 * @returns {Array} a new flat array.
 */
function flatten(arr) {
  return arr.flat(Infinity);
}
输出限制 #
javascript
/**
 * Maps model output values into a safe numeric range.
 * NaN becomes 0, +/-Infinity becomes maxValue/minValue, and finite values are clamped
 * to [minValue, maxValue].
 *
 * @param {Array<number>|TypedArray} output - values to sanitize (a typed array yields the same typed array type).
 * @param {{maxValue?: number, minValue?: number}} [options] - bounds (defaults 1e6 / -1e6).
 * @returns {Array<number>|TypedArray} a new collection of sanitized values.
 */
function sanitizeOutput(output, options = {}) {
  const { maxValue = 1e6, minValue = -1e6 } = options;
  const clean = (value) => {
    if (isNaN(value)) {
      return 0;
    }
    if (isFinite(value)) {
      return Math.max(minValue, Math.min(maxValue, value));
    }
    return value > 0 ? maxValue : minValue;
  };
  return output.map(clean);
}
完整项目模板 #
javascript
const tf = require('@tensorflow/tfjs-node');
/**
 * Production-style inference service: loads a LayersModel, warms it up, validates and sanitizes
 * predictions, and reports health. Relies on TFLogger, ResourceManager, validateInput and
 * sanitizeOutput defined in the same tutorial.
 */
class MLService {
  /**
   * @param {Object} config
   * @param {string} config.modelUrl - URL/path passed to tf.loadLayersModel.
   * @param {string} [config.logLevel] - threshold forwarded to TFLogger.
   * @param {Object} [config.resources] - options forwarded to ResourceManager.
   */
  constructor(config) {
    this.config = config;
    this.model = null;
    this.logger = new TFLogger(config.logLevel);
    this.resourceManager = new ResourceManager(config.resources);
  }

  /** Loads the model and runs one warm-up prediction; rethrows load errors after logging them. */
  async init() {
    this.logger.info('初始化服务...');
    try {
      this.model = await tf.loadLayersModel(this.config.modelUrl);
      await this._warmup();
      this.logger.info('模型加载完成');
    } catch (error) {
      this.logger.error('模型加载失败', { error: error.message });
      throw error;
    }
  }

  /** One forward pass on zeros, presumably so the first real request avoids one-time setup cost. */
  async _warmup() {
    // Unknown (null) dimensions cannot be allocated; use 1 for them in the dummy input.
    const inputShape = this.model.inputShape.slice(1).map((d) => (d == null ? 1 : d));
    const dummy = tf.zeros([1, ...inputShape]);
    const output = this.model.predict(dummy);
    try {
      await output.data();
    } finally {
      dummy.dispose();
      output.dispose();
    }
  }

  /** @throws {Error} if init() has not loaded a model yet. */
  _requireModel() {
    if (this.model === null) {
      throw new Error('模型尚未加载,请先调用 init()');
    }
    return this.model;
  }

  /**
   * Predicts a single sample (without batch dimension).
   * @returns {Promise<TypedArray>} sanitized output values.
   */
  async predict(input) {
    return this.resourceManager.wrapOperation(async () => {
      const model = this._requireModel();
      validateInput(input, model.inputShape.slice(1));
      const output = tf.tidy(() => model.predict(tf.tensor(input).expandDims(0)));
      try {
        // Async readback keeps the event loop free, unlike dataSync().
        return sanitizeOutput(await output.data());
      } finally {
        output.dispose();
      }
    })();
  }

  /**
   * Predicts many samples in a single forward pass.
   * @param {Array} inputs - samples of identical shape.
   * @returns {Promise<TypedArray[]>} one sanitized result per input, in order.
   */
  async predictBatch(inputs) {
    if (inputs.length === 0) {
      return [];
    }
    return this.resourceManager.wrapOperation(async () => {
      const model = this._requireModel();
      const sampleShape = model.inputShape.slice(1);
      inputs.forEach((input) => validateInput(input, sampleShape));
      // Stack into one batch instead of running N separate predictions.
      const output = tf.tidy(() => model.predict(tf.stack(inputs.map((input) => tf.tensor(input)))));
      try {
        const flat = await output.data();
        const perSample = flat.length / inputs.length;
        return inputs.map((_, i) => sanitizeOutput(flat.subarray(i * perSample, (i + 1) * perSample)));
      } finally {
        output.dispose();
      }
    })();
  }

  /** Reports 'healthy' only when a model is loaded, plus current tf.memory() counters. */
  getHealth() {
    const mem = tf.memory();
    const modelLoaded = this.model !== null;
    return {
      status: modelLoaded ? 'healthy' : 'unavailable',
      modelLoaded,
      memory: {
        tensors: mem.numTensors,
        bytes: mem.numBytes,
      },
    };
  }

  /** Disposes the model (once) and clears the reference so health/predict reflect shutdown. */
  dispose() {
    if (this.model) {
      this.model.dispose();
      this.model = null;
    }
    this.logger.info('服务已关闭');
  }
}
module.exports = { MLService };
总结 #
TensorFlow.js 是一个强大的机器学习库,通过本系列教程,你已经学习了:
- 基础概念:张量、运算、模型构建
- 核心进阶:层配置、训练方法、优化算法
- 高级应用:CNN、RNN、迁移学习,以及本章介绍的自定义组件、性能优化、调试技巧、生产实践与安全考虑
- 实战部署:浏览器推理、Node.js 部署
继续实践和探索,你将成为 TensorFlow.js 专家!
最后更新:2026-03-29