const fs = require('fs');
const path = require('path');
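// A simplified, self-contained simulation of an LLM training pipeline:
// pretraining, supervised fine-tuning (SFT), RLHF-style data formatting,
// and LoRA-based parameter-efficient fine-tuning. Losses and parameter
// updates are mocked; the focus is on data flow and training-loop structure.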
class LLMTrainer {
constructor(config) {
this.config = {
modelConfig: config.modelConfig || {},
trainingConfig: config.trainingConfig || {},
dataConfig: config.dataConfig || {}
};
this.model = null;
this.optimizer = null;
this.trainingHistory = [];
    console.log('LLM trainer initialized');
this.printConfig();
}
printConfig() {
    console.log('\n=== Training Configuration ===');
    console.log('Model config:', JSON.stringify(this.config.modelConfig, null, 2));
    console.log('Training config:', JSON.stringify(this.config.trainingConfig, null, 2));
    console.log('Data config:', JSON.stringify(this.config.dataConfig, null, 2));
}
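  // Preprocessing pipeline: clean -> format for the training stage -> tokenize,
  // then report basic statistics on the resulting samples.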
preprocessData(rawData, stage = 'pretrain') {
    console.log(`\n=== Data preprocessing (${stage}) ===`);
    console.log(`Raw samples: ${rawData.length}`);
    let processedData = rawData;
    processedData = this.cleanData(processedData);
    console.log(`Samples after cleaning: ${processedData.length}`);
    processedData = this.formatData(processedData, stage);
    console.log(`Samples after formatting: ${processedData.length}`);
    processedData = this.tokenizeData(processedData);
    console.log(`Samples after tokenization: ${processedData.length}`);
    const stats = this.computeDataStats(processedData);
    console.log('Data statistics:', stats);
return processedData;
}
cleanData(data) {
    // Drop empty/non-string items, length outliers, flagged content, and duplicates.
    console.log('Cleaning data...');
return data.filter(item => {
if (!item || typeof item !== 'string') return false;
if (item.length < 10 || item.length > 2048) return false;
if (this.containsHarmfulContent(item)) return false;
if (this.isDuplicate(item)) return false;
return true;
});
}
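  // Simple keyword-based content filter; stands in for a real safety classifier.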
containsHarmfulContent(text) {
const harmfulPatterns = [
/\b(hate|violence|illegal)\b/i,
/\b(spam|scam|fraud)\b/i
];
return harmfulPatterns.some(pattern => pattern.test(text));
}
isDuplicate(text) {
if (!this.seenTexts) {
this.seenTexts = new Set();
}
const hash = this.simpleHash(text);
if (this.seenTexts.has(hash)) {
return true;
}
this.seenTexts.add(hash);
return false;
}
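  // 31-based rolling string hash (same recurrence as Java's String.hashCode):
  // hash = hash * 31 + charCode, truncated to a 32-bit integer. Collisions are
  // possible, so this is only a rough deduplication key.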
simpleHash(text) {
let hash = 0;
for (let i = 0; i < text.length; i++) {
const char = text.charCodeAt(i);
hash = ((hash << 5) - hash) + char;
hash = hash & hash;
}
return hash;
}
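  // Convert raw samples into the shape expected by each training stage:
  //   pretrain: { input, target } with target === input (next-token prediction)
  //   sft:      { instruction, input, output } supervised pairs
  //   rlhf:     { prompt, chosen, rejected } preference pairs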
formatData(data, stage) {
    console.log(`Formatting data for stage: ${stage}...`);
switch (stage) {
case 'pretrain':
return this.formatPretrainData(data);
case 'sft':
return this.formatSFTData(data);
case 'rlhf':
return this.formatRLHFData(data);
default:
return data;
}
}
formatPretrainData(data) {
return data.map(text => ({
input: text,
target: text,
type: 'pretrain'
}));
}
formatSFTData(data) {
return data.map(item => {
if (typeof item === 'string') {
return {
          instruction: 'Please continue the following text:',
input: item.substring(0, item.length / 2),
output: item.substring(item.length / 2),
type: 'sft'
};
}
return item;
});
}
formatRLHFData(data) {
return data.map(item => ({
prompt: item.prompt || '',
chosen: item.chosen || '',
rejected: item.rejected || '',
type: 'rlhf'
}));
}
tokenizeData(data) {
    console.log('Tokenizing data...');
return data.map(item => {
const tokenized = { ...item };
if (item.input) {
tokenized.input_tokens = this.tokenize(item.input);
}
if (item.target) {
tokenized.target_tokens = this.tokenize(item.target);
}
if (item.output) {
tokenized.output_tokens = this.tokenize(item.output);
}
return tokenized;
});
}
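  // Toy character-level tokenizer: each character is mapped to its char code
  // modulo 1000. A real pipeline would use a subword tokenizer such as BPE.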
tokenize(text) {
return text.split('').map(char =>
char.charCodeAt(0) % 1000
);
}
  computeDataStats(data) {
    const stats = {
      totalSamples: data.length,
      avgInputLength: 0,
      avgOutputLength: 0,
      maxLength: 0,
      minLength: Infinity,
      typeDistribution: {}
    };
    let totalInputLength = 0;
    let totalOutputLength = 0;
    let inputCount = 0;
    let outputCount = 0;
    data.forEach(item => {
      const type = item.type || 'unknown';
      stats.typeDistribution[type] = (stats.typeDistribution[type] || 0) + 1;
      if (item.input_tokens) {
        const length = item.input_tokens.length;
        totalInputLength += length;
        inputCount++;
        stats.maxLength = Math.max(stats.maxLength, length);
        stats.minLength = Math.min(stats.minLength, length);
      }
      if (item.output_tokens) {
        totalOutputLength += item.output_tokens.length;
        outputCount++;
      }
    });
    // Average over the samples that actually have the corresponding tokens,
    // and guard against an empty data set leaving minLength at Infinity.
    stats.avgInputLength = inputCount > 0 ? Math.round(totalInputLength / inputCount) : 0;
    stats.avgOutputLength = outputCount > 0 ? Math.round(totalOutputLength / outputCount) : 0;
    if (stats.minLength === Infinity) stats.minLength = 0;
    return stats;
  }
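  // Pretraining loop: iterate over mini-batches of the preprocessed corpus,
  // accumulate the (simulated) language-modeling loss, and record the average
  // loss per epoch in trainingHistory.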
async pretrain(data, epochs = 1) {
    console.log('\n=== Starting pretraining ===');
    console.log(`Samples: ${data.length}, epochs: ${epochs}`);
const processedData = this.preprocessData(data, 'pretrain');
for (let epoch = 0; epoch < epochs; epoch++) {
      console.log(`\n--- Pretraining epoch ${epoch + 1}/${epochs} ---`);
let totalLoss = 0;
const batchSize = this.config.trainingConfig.batchSize || 32;
const numBatches = Math.ceil(processedData.length / batchSize);
for (let batchIdx = 0; batchIdx < numBatches; batchIdx++) {
const batch = processedData.slice(
batchIdx * batchSize,
(batchIdx + 1) * batchSize
);
const batchLoss = this.trainBatch(batch, 'pretrain');
totalLoss += batchLoss;
if (batchIdx % 100 === 0) {
console.log(`Batch ${batchIdx}/${numBatches}, Loss: ${batchLoss.toFixed(4)}`);
}
}
const avgLoss = totalLoss / numBatches;
      console.log(`Epoch ${epoch + 1} average loss: ${avgLoss.toFixed(4)}`);
this.trainingHistory.push({
epoch: epoch + 1,
stage: 'pretrain',
loss: avgLoss,
timestamp: new Date().toISOString()
});
}
    console.log('Pretraining complete');
}
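  // Supervised fine-tuning (SFT): same batch loop as pretraining, but with a
  // reduced learning rate and loss-based early stopping.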
async supervisedFineTune(data, epochs = 3) {
    console.log('\n=== Starting supervised fine-tuning ===');
    console.log(`Samples: ${data.length}, epochs: ${epochs}`);
const processedData = this.preprocessData(data, 'sft');
    // Fine-tune at 10% of the pretraining learning rate to limit catastrophic
    // forgetting; the original rate is restored after training.
    const originalLR = this.config.trainingConfig.learningRate || 1e-4;
    this.config.trainingConfig.learningRate = originalLR * 0.1;
for (let epoch = 0; epoch < epochs; epoch++) {
console.log(`\n--- SFT Epoch ${epoch + 1}/${epochs} ---`);
let totalLoss = 0;
const batchSize = this.config.trainingConfig.batchSize || 16;
const numBatches = Math.ceil(processedData.length / batchSize);
for (let batchIdx = 0; batchIdx < numBatches; batchIdx++) {
const batch = processedData.slice(
batchIdx * batchSize,
(batchIdx + 1) * batchSize
);
const batchLoss = this.trainBatch(batch, 'sft');
totalLoss += batchLoss;
if (batchIdx % 50 === 0) {
console.log(`Batch ${batchIdx}/${numBatches}, Loss: ${batchLoss.toFixed(4)}`);
}
}
const avgLoss = totalLoss / numBatches;
      console.log(`Epoch ${epoch + 1} average loss: ${avgLoss.toFixed(4)}`);
this.trainingHistory.push({
epoch: epoch + 1,
stage: 'sft',
loss: avgLoss,
timestamp: new Date().toISOString()
});
if (this.shouldEarlyStop()) {
        console.log('Early stopping triggered, ending training');
break;
}
}
    // Restore the original learning rate for any later training stages.
    this.config.trainingConfig.learningRate = originalLR;
    console.log('Supervised fine-tuning complete');
}
trainBatch(batch, stage) {
let batchLoss = 0;
batch.forEach(item => {
const loss = this.computeLoss(item, stage);
batchLoss += loss;
this.updateParameters(loss);
});
return batchLoss / batch.length;
}
computeLoss(item, stage) {
switch (stage) {
case 'pretrain':
return this.computeLanguageModelingLoss(item);
case 'sft':
return this.computeSupervisionLoss(item);
case 'rlhf':
return this.computeRLHFLoss(item);
default:
return Math.random();
}
}
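  // Pseudo cross-entropy for the demo: tokens that are numerically close count
  // as "likely", so the per-token loss is -log(1 - |predicted - target| / 1000),
  // clamped away from log(0).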
computeLanguageModelingLoss(item) {
const inputTokens = item.input_tokens || [];
const targetTokens = item.target_tokens || [];
if (inputTokens.length === 0 || targetTokens.length === 0) {
return 1.0;
}
let loss = 0;
for (let i = 0; i < Math.min(inputTokens.length, targetTokens.length); i++) {
const predicted = inputTokens[i];
const target = targetTokens[i];
loss += -Math.log(Math.max(0.001, 1 - Math.abs(predicted - target) / 1000));
}
return loss / Math.min(inputTokens.length, targetTokens.length);
}
  computeSupervisionLoss(item) {
    const outputTokens = item.output_tokens || [];
    if (outputTokens.length === 0) {
      return 1.0;
    }
    // Placeholder: a real implementation would compute cross-entropy on the output tokens.
    return Math.random() * 0.5 + 0.1;
  }
  computeRLHFLoss(item) {
    // Placeholder: a real implementation would compute a preference-based objective.
    return Math.random() * 0.3 + 0.05;
  }
  updateParameters(loss) {
    // Placeholder for an optimizer step; this simulation holds no real weights.
    const lr = this.config.trainingConfig.learningRate || 1e-4;
    return true;
  }
  shouldEarlyStop() {
    // Stop when the loss for the current stage has risen for two consecutive epochs.
    // Only entries from the current stage are considered, so earlier pretraining
    // history does not affect SFT early stopping.
    if (this.trainingHistory.length === 0) return false;
    const currentStage = this.trainingHistory[this.trainingHistory.length - 1].stage;
    const stageHistory = this.trainingHistory.filter(h => h.stage === currentStage);
    if (stageHistory.length < 3) return false;
    const recentLosses = stageHistory.slice(-3).map(h => h.loss);
    return recentLosses[1] > recentLosses[0] && recentLosses[2] > recentLosses[1];
  }
saveTrainingHistory(filepath) {
const historyData = {
config: this.config,
history: this.trainingHistory,
timestamp: new Date().toISOString()
};
fs.writeFileSync(filepath, JSON.stringify(historyData, null, 2));
    console.log(`Training history saved to: ${filepath}`);
}
visualizeTraining() {
    console.log('\n=== Training Visualization ===');
    if (this.trainingHistory.length === 0) {
      console.log('No training history available');
return;
}
const stageGroups = {};
this.trainingHistory.forEach(record => {
if (!stageGroups[record.stage]) {
stageGroups[record.stage] = [];
}
stageGroups[record.stage].push(record);
});
Object.entries(stageGroups).forEach(([stage, records]) => {
      console.log(`\n${stage.toUpperCase()} loss curve:`);
console.log('Epoch\tLoss\t\tTrend');
console.log(''.padEnd(30, '-'));
records.forEach((record, idx) => {
const trend = idx > 0 ?
(record.loss < records[idx - 1].loss ? '↓' : '↑') : '-';
console.log(`${record.epoch}\t${record.loss.toFixed(4)}\t\t${trend}`);
});
});
const allLosses = this.trainingHistory.map(h => h.loss);
const minLoss = Math.min(...allLosses);
const maxLoss = Math.max(...allLosses);
const finalLoss = allLosses[allLosses.length - 1];
    console.log('\n=== Training Summary ===');
    console.log(`Total epochs recorded: ${this.trainingHistory.length}`);
    console.log(`Lowest loss: ${minLoss.toFixed(4)}`);
    console.log(`Highest loss: ${maxLoss.toFixed(4)}`);
    console.log(`Final loss: ${finalLoss.toFixed(4)}`);
    console.log(`Loss improvement: ${((maxLoss - finalLoss) / maxLoss * 100).toFixed(2)}%`);
}
}
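// LoRA (Low-Rank Adaptation): instead of updating a full d x d weight matrix,
// train two small factors A (d x r) and B (r x d) and add their scaled product
// to the frozen base weight. In this demo's right-multiplication convention:
// W' = W + (alpha / r) * A * B. Only A and B are trained, which sharply reduces
// the number of trainable parameters.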
class LoRATrainer {
constructor(baseModel, config) {
this.baseModel = baseModel;
this.config = {
rank: config.rank || 16,
alpha: config.alpha || 32,
dropout: config.dropout || 0.1,
targetModules: config.targetModules || ['q_proj', 'v_proj']
};
this.loraLayers = {};
this.initializeLoRALayers();
    console.log('LoRA trainer initialized');
this.printLoRAInfo();
}
  initializeLoRALayers() {
    console.log('Initializing LoRA layers...');
    this.config.targetModules.forEach(moduleName => {
      this.loraLayers[moduleName] = {
        // A maps hidden (768) -> rank, B maps rank -> hidden (768), matching the
        // x · A · B order used in loraForward. (In practice B is zero-initialized
        // so the adapter starts as a no-op; random init is kept here for the demo.)
        A: this.initializeMatrix(768, this.config.rank),
        B: this.initializeMatrix(this.config.rank, 768),
        scaling: this.config.alpha / this.config.rank
      };
    });
  }
  initializeMatrix(rows, cols) {
    const matrix = [];
    // Xavier/Glorot-style scale based on fan-in + fan-out keeps activations in range.
    const scale = Math.sqrt(2 / (rows + cols));
for (let i = 0; i < rows; i++) {
const row = [];
for (let j = 0; j < cols; j++) {
row.push((Math.random() - 0.5) * 2 * scale);
}
matrix.push(row);
}
return matrix;
}
printLoRAInfo() {
    console.log('\n=== LoRA Configuration ===');
    console.log(`Rank: ${this.config.rank}`);
    console.log(`Alpha (scaling factor): ${this.config.alpha}`);
    console.log(`Dropout: ${this.config.dropout}`);
    console.log(`Target modules: [${this.config.targetModules.join(', ')}]`);
    // Compare a full 768x768 weight per module against two low-rank factors (768 x r and r x 768).
    const originalParams = 768 * 768 * this.config.targetModules.length;
    const loraParams = (768 + 768) * this.config.rank * this.config.targetModules.length;
    const reduction = ((originalParams - loraParams) / originalParams * 100).toFixed(2);
    console.log('\nParameter comparison:');
    console.log(`Original parameters: ${originalParams.toLocaleString()}`);
    console.log(`LoRA parameters: ${loraParams.toLocaleString()}`);
    console.log(`Parameter reduction: ${reduction}%`);
}
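  // Forward pass with the low-rank update: compute (x · A) · B, scale it by
  // alpha / rank, and add it to the base activation. For simplicity the base
  // projection output is approximated here by x itself.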
loraForward(x, moduleName) {
if (!this.loraLayers[moduleName]) {
return x;
}
const lora = this.loraLayers[moduleName];
const xA = this.matmul(x, lora.A);
const xAB = this.matmul(xA, lora.B);
const loraOutput = xAB.map(row =>
row.map(val => val * lora.scaling)
);
return this.addMatrices(x, loraOutput);
}
matmul(a, b) {
const result = [];
for (let i = 0; i < a.length; i++) {
const row = [];
for (let j = 0; j < b[0].length; j++) {
let sum = 0;
for (let k = 0; k < b.length; k++) {
sum += a[i][k] * b[k][j];
}
row.push(sum);
}
result.push(row);
}
return result;
}
addMatrices(a, b) {
return a.map((row, i) =>
row.map((val, j) => val + (b[i] ? b[i][j] || 0 : 0))
);
}
async fineTune(data, epochs = 5) {
    console.log('\n=== Starting LoRA fine-tuning ===');
    console.log(`Samples: ${data.length}, epochs: ${epochs}`);
for (let epoch = 0; epoch < epochs; epoch++) {
      console.log(`\n--- LoRA epoch ${epoch + 1}/${epochs} ---`);
let totalLoss = 0;
for (let i = 0; i < data.length; i++) {
const item = data[i];
const loss = this.computeLoRALoss(item);
totalLoss += loss;
this.updateLoRAParameters(loss);
if (i % 100 === 0) {
console.log(`Sample ${i}/${data.length}, Loss: ${loss.toFixed(4)}`);
}
}
const avgLoss = totalLoss / data.length;
      console.log(`Epoch ${epoch + 1} average loss: ${avgLoss.toFixed(4)}`);
}
    console.log('LoRA fine-tuning complete');
}
  computeLoRALoss(item) {
    // Placeholder: a real implementation would run a forward pass through the
    // adapted model and compute cross-entropy against the target tokens.
    return Math.random() * 0.5 + 0.1;
  }
  updateLoRAParameters(loss) {
    // Simulated SGD step: only the low-rank A and B factors are updated;
    // the base model weights stay frozen.
    const lr = 1e-4;
Object.values(this.loraLayers).forEach(lora => {
lora.A = lora.A.map(row =>
row.map(val => val - lr * (Math.random() - 0.5) * loss)
);
lora.B = lora.B.map(row =>
row.map(val => val - lr * (Math.random() - 0.5) * loss)
);
});
}
saveLoRAWeights(filepath) {
const loraData = {
config: this.config,
weights: this.loraLayers,
timestamp: new Date().toISOString()
};
fs.writeFileSync(filepath, JSON.stringify(loraData, null, 2));
    console.log(`LoRA weights saved to: ${filepath}`);
}
}
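// Illustrative sketch (not part of the pipeline above): once LoRA training is
// done, the low-rank factors can be merged back into a dense weight so inference
// needs no extra matrix multiplications. This assumes the same right-multiplication
// convention as LoRATrainer.loraForward, i.e. deltaW = A · B scaled by alpha / rank;
// `baseWeight` is a hypothetical 768 x 768 matrix supplied by the caller.
function mergeLoRAIntoWeight(baseWeight, lora) {
  const rows = baseWeight.length;
  const cols = baseWeight[0].length;
  const rank = lora.A[0].length;
  const merged = [];
  for (let i = 0; i < rows; i++) {
    const row = [];
    for (let j = 0; j < cols; j++) {
      // deltaW[i][j] = sum_k A[i][k] * B[k][j], scaled by alpha / rank
      let delta = 0;
      for (let k = 0; k < rank; k++) {
        delta += lora.A[i][k] * lora.B[k][j];
      }
      row.push(baseWeight[i][j] + lora.scaling * delta);
    }
    merged.push(row);
  }
  return merged;
}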
const trainingConfig = {
modelConfig: {
vocabSize: 50000,
hiddenSize: 768,
numLayers: 12,
numHeads: 12
},
trainingConfig: {
batchSize: 32,
learningRate: 1e-4,
maxSteps: 10000
},
dataConfig: {
maxLength: 512,
dataPath: './training_data'
}
};
const trainer = new LLMTrainer(trainingConfig);
const pretrainData = [
  'This is a sample text used for pretraining.',
  'Large language models require massive amounts of text data for training.',
  'The pretraining stage learns the statistical patterns and knowledge of language.'
];
const sftData = [
{
    instruction: 'Explain what machine learning is',
    input: '',
    output: 'Machine learning is a branch of artificial intelligence that uses algorithms to let computers learn patterns from data.'
  },
  {
    instruction: 'Write a short poem about spring',
    input: '',
    output: 'Spring breeze greens the willows, blossoms crowd the boughs; birdsong rings sweet as the earth puts on new clothes.'
}
];
async function runTraining() {
try {
await trainer.pretrain(pretrainData, 2);
await trainer.supervisedFineTune(sftData, 3);
trainer.visualizeTraining();
trainer.saveTrainingHistory('./training_history.json');
const loraConfig = {
rank: 16,
alpha: 32,
dropout: 0.1,
targetModules: ['q_proj', 'v_proj', 'k_proj', 'o_proj']
};
const loraTrainer = new LoRATrainer(null, loraConfig);
await loraTrainer.fineTune(sftData, 3);
loraTrainer.saveLoRAWeights('./lora_weights.json');
} catch (error) {
    console.error('Error during training:', error);
}
}
runTraining();