trainModelFromDataset (exported function)
Last updated: 2026-02-24T19:46:21.733Z
Metrics
LOC: 192
Complexity: 16
Params: 2
Signature
trainModelFromDataset(
entries: DatasetEntry[],
onProgress?: (p: TrainingProgress) => void,
): Promise<TrainingResult>
Summary
Trains an MLP classifier from the provided dataset entries, stores the resulting model in chrome.storage.local, and returns training metrics. Must be called from an extension page (options/popup), not from a content script, because the TF.js WebGL backend requires a live DOM context.
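A minimal usage sketch from an options page (hypothetical caller; getStoredDataset is an assumed helper, not part of this module):

// Hypothetical options-page caller. getStoredDataset is an assumed helper.
const entries: DatasetEntry[] = await getStoredDataset();
const result = await trainModelFromDataset(entries, (p) => {
  console.log(`epoch ${p.epoch}/${p.totalEpochs}: loss=${p.loss.toFixed(4)}, acc=${p.accuracy.toFixed(4)}`);
});
if (!result.success) console.error(result.error);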
Architecture violations
- [warning] max-cyclomatic-complexity: 'trainModelFromDataset' has cyclomatic complexity 16 (max 10)
- [warning] max-lines: 'trainModelFromDataset' has 192 lines (max 80)
Tags
#@param entries Dataset entries to train on.
#@param onProgress Optional callback invoked after every epoch.
Source Code
export async function trainModelFromDataset(
entries: DatasetEntry[],
onProgress?: (p: TrainingProgress) => void,
): Promise<TrainingResult> {
const t0 = Date.now();
// Filter to supported labels only + convert to structured feature text
const samples = entries
.filter((e) => LABEL_SET.has(e.type))
.map((entry) => {
const structured = fromFlatSignals(entry.signals);
const category = inferCategoryFromType(entry.type);
const language = inferLanguageFromSignals(entry.signals);
const featureText = buildFeatureText(structured, {
category,
language,
});
return {
...entry,
featureText,
};
})
.filter((entry) => entry.featureText.length > 0);
if (samples.length < 10) {
return {
success: false,
error: `Dataset too small (${samples.length} valid samples). Minimum: 10.`,
trainedAt: Date.now(),
epochs: 0,
finalLoss: 0,
finalAccuracy: 0,
vocabSize: 0,
numClasses: 0,
entriesUsed: 0,
durationMs: 0,
};
}
// Determine labels actually present in this dataset
const presentLabels = LABELS.filter((l) => samples.some((s) => s.type === l));
const labelToIdx = Object.fromEntries(presentLabels.map((l, i) => [l, i]));
const numClasses = presentLabels.length;
if (numClasses < 2) {
return {
success: false,
error: `The dataset needs at least 2 different field types to train. Found: ${numClasses} type (${presentLabels[0] ?? "none"}). Add samples of other types or import the default dataset.`,
trainedAt: Date.now(),
epochs: 0,
finalLoss: 0,
finalAccuracy: 0,
vocabSize: 0,
numClasses,
entriesUsed: samples.length,
durationMs: 0,
};
}
// Build vocab from all structured feature texts
const featureTexts = samples.map((sample) => sample.featureText);
const vocab = buildVocab(featureTexts);
const vocabSize = vocab.size;
// Vectorise
const X = samples.map((sample) =>
Array.from(vectorize(sample.featureText, vocab)),
);
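// One-hot encode each label against the indices of the labels present in this dataset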
const Y = samples.map((s) => {
const oneHot = new Array<number>(numClasses).fill(0);
oneHot[labelToIdx[s.type]] = 1;
return oneHot;
});
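// Load TF.js lazily so the library is only pulled in when training actually runs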
const tf = await import("@tensorflow/tfjs");
await tf.ready();
const xTensor = tf.tensor2d(X, [X.length, vocabSize]);
const yTensor = tf.tensor2d(Y, [Y.length, numClasses]);
// Model architecture (mirrors train-model.ts)
const model = tf.sequential({
layers: [
tf.layers.dense({
inputShape: [vocabSize],
units: 256,
activation: "relu",
kernelRegularizer: tf.regularizers.l2({ l2: 1e-4 }),
}),
tf.layers.dropout({ rate: 0.3 }),
tf.layers.dense({
units: 128,
activation: "relu",
kernelRegularizer: tf.regularizers.l2({ l2: 1e-4 }),
}),
tf.layers.dropout({ rate: 0.2 }),
tf.layers.dense({ units: numClasses, activation: "softmax" }),
],
});
model.compile({
optimizer: tf.train.adam(0.001),
loss: "categoricalCrossentropy",
metrics: ["accuracy"],
});
// Early stopping with best-weights checkpoint
let bestAcc = -1;
let patience = 0;
const bestWeights: Tensor[] = [];
let finalLoss = 0;
let finalAccuracy = 0;
let lastEpoch = 0;
const onEpochEnd = async (epoch: number, logs?: Record<string, number>) => {
const acc = logs?.["acc"] ?? logs?.["accuracy"] ?? 0;
const loss = logs?.["loss"] ?? 0;
lastEpoch = epoch + 1;
finalLoss = loss;
finalAccuracy = acc;
if (acc > bestAcc) {
bestAcc = acc;
patience = 0;
while (bestWeights.length) bestWeights.pop()!.dispose();
model.getWeights().forEach((w) => bestWeights.push(w.clone()));
} else {
patience++;
if (patience >= PATIENCE) model.stopTraining = true;
}
onProgress?.({
epoch: epoch + 1,
totalEpochs: EPOCHS,
loss,
accuracy: acc,
});
};
try {
await model.fit(xTensor, yTensor, {
epochs: EPOCHS,
batchSize: BATCH_SIZE,
shuffle: true,
callbacks: [new tf.CustomCallback({ onEpochEnd })],
});
// Restore best weights
if (bestWeights.length > 0) {
model.setWeights(bestWeights);
while (bestWeights.length) bestWeights.pop()!.dispose();
}
xTensor.dispose();
yTensor.dispose();
const meta: TrainingMeta = {
trainedAt: Date.now(),
epochs: lastEpoch,
finalLoss,
finalAccuracy: bestAcc >= 0 ? bestAcc : finalAccuracy,
vocabSize,
numClasses,
entriesUsed: samples.length,
durationMs: Date.now() - t0,
};
await saveModelToStorage(model, vocab, presentLabels, meta);
model.dispose();
return { success: true, ...meta };
} catch (err) {
// Dispose checkpointed weight clones as well as the inputs and model before surfacing the error
while (bestWeights.length) bestWeights.pop()!.dispose();
xTensor.dispose();
yTensor.dispose();
model.dispose();
return {
success: false,
error: err instanceof Error ? err.message : String(err),
trainedAt: Date.now(),
epochs: lastEpoch,
finalLoss,
finalAccuracy,
vocabSize,
numClasses,
entriesUsed: samples.length,
durationMs: Date.now() - t0,
};
}
}
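Inferred Types
The shapes below are a reference sketch reconstructed from the return sites in the source above; the canonical declarations live elsewhere in the codebase.

// Sketch only: field lists inferred from the code above, not the canonical declarations.
interface TrainingProgress {
  epoch: number; // 1-based index of the epoch that just finished
  totalEpochs: number; // the EPOCHS constant
  loss: number;
  accuracy: number;
}

interface TrainingMeta {
  trainedAt: number; // Date.now() timestamp
  epochs: number; // epochs actually run (early stopping may cut this short)
  finalLoss: number;
  finalAccuracy: number; // best observed accuracy when available
  vocabSize: number;
  numClasses: number;
  entriesUsed: number;
  durationMs: number;
}

type TrainingResult = TrainingMeta & { success: boolean; error?: string };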
Dependencies (Outgoing)
| Target | Type |
|---|---|
| buildVocab | calls |
| vectorize | calls |
| saveModelToStorage | calls |
Impact (Incoming)
| Source | Type |
|---|---|
| loadModelStatus | uses |