trainModelFromDataset (function · domain · exported)

Last updated: 2026-02-24T19:46:21.733Z

Metrics

LOC: 192 · Complexity: 16 · Params: 2

Signature

trainModelFromDataset(entries: DatasetEntry[], onProgress?: (p: TrainingProgress) => void): Promise<TrainingResult>

Summary

Trains an MLP classifier from the provided dataset entries, stores the resulting model in chrome.storage.local, and returns training metrics. Must be called from an extension page (options/popup), not from a content script, because the TF.js WebGL backend requires a live DOM context.
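
A minimal usage sketch from an options-page script. The import paths and the loadDataset() helper are assumptions for illustration, not part of this module's documented API:

// Hypothetical wiring: the paths and loadDataset() are assumptions.
import { trainModelFromDataset } from "./train-model-from-dataset";
import { loadDataset } from "./dataset-store";

async function onTrainButtonClick(): Promise<void> {
  const entries = await loadDataset();

  const result = await trainModelFromDataset(entries, (p) => {
    // Invoked after every epoch; suitable for driving a progress bar.
    console.log(`epoch ${p.epoch}/${p.totalEpochs} loss=${p.loss.toFixed(4)} acc=${p.accuracy.toFixed(4)}`);
  });

  if (!result.success) {
    console.error("Training failed:", result.error);
    return;
  }

  console.log(`Trained ${result.numClasses} classes on ${result.entriesUsed} entries in ${result.durationMs} ms`);
}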

Architecture violations


  • [warning] max-cyclomatic-complexity: 'trainModelFromDataset' has cyclomatic complexity 16 (max 10)
  • [warning] max-lines: 'trainModelFromDataset' has 192 lines (max 80)

Tags

@param entries Dataset entries to train on.
@param onProgress Optional callback invoked after every epoch.

Source Code

export async function trainModelFromDataset(
  entries: DatasetEntry[],
  onProgress?: (p: TrainingProgress) => void,
): Promise<TrainingResult> {
  const t0 = Date.now();

  // Filter to supported labels only + convert to structured feature text
  const samples = entries
    .filter((e) => LABEL_SET.has(e.type))
    .map((entry) => {
      const structured = fromFlatSignals(entry.signals);
      const category = inferCategoryFromType(entry.type);
      const language = inferLanguageFromSignals(entry.signals);
      const featureText = buildFeatureText(structured, {
        category,
        language,
      });

      return {
        ...entry,
        featureText,
      };
    })
    .filter((entry) => entry.featureText.length > 0);

  if (samples.length < 10) {
    return {
      success: false,
      error: `Dataset too small (${samples.length} valid samples). Minimum: 10.`,
      trainedAt: Date.now(),
      epochs: 0,
      finalLoss: 0,
      finalAccuracy: 0,
      vocabSize: 0,
      numClasses: 0,
      entriesUsed: 0,
      durationMs: 0,
    };
  }

  // Determine labels actually present in this dataset
  const presentLabels = LABELS.filter((l) => samples.some((s) => s.type === l));
  const labelToIdx = Object.fromEntries(presentLabels.map((l, i) => [l, i]));
  const numClasses = presentLabels.length;

  if (numClasses < 2) {
    return {
      success: false,
      error: `The dataset needs at least 2 different field types to train. Found: ${numClasses} type (${presentLabels[0] ?? "none"}). Add samples of other types or import the default dataset.`,
      trainedAt: Date.now(),
      epochs: 0,
      finalLoss: 0,
      finalAccuracy: 0,
      vocabSize: 0,
      numClasses,
      entriesUsed: samples.length,
      durationMs: 0,
    };
  }

  // Build vocab from all structured feature texts
  const featureTexts = samples.map((sample) => sample.featureText);
  const vocab = buildVocab(featureTexts);
  const vocabSize = vocab.size;

  // Vectorise
  const X = samples.map((sample) =>
    Array.from(vectorize(sample.featureText, vocab)),
  );
  const Y = samples.map((s) => {
    const oneHot = new Array<number>(numClasses).fill(0);
    oneHot[labelToIdx[s.type]] = 1;
    return oneHot;
  });

  // Lazy-load TF.js so the heavy bundle is only fetched when training actually runs.
  const tf = await import("@tensorflow/tfjs");
  await tf.ready();

  const xTensor = tf.tensor2d(X, [X.length, vocabSize]);
  const yTensor = tf.tensor2d(Y, [Y.length, numClasses]);

  // Model architecture (mirrors train-model.ts)
  const model = tf.sequential({
    layers: [
      tf.layers.dense({
        inputShape: [vocabSize],
        units: 256,
        activation: "relu",
        kernelRegularizer: tf.regularizers.l2({ l2: 1e-4 }),
      }),
      tf.layers.dropout({ rate: 0.3 }),
      tf.layers.dense({
        units: 128,
        activation: "relu",
        kernelRegularizer: tf.regularizers.l2({ l2: 1e-4 }),
      }),
      tf.layers.dropout({ rate: 0.2 }),
      tf.layers.dense({ units: numClasses, activation: "softmax" }),
    ],
  });

  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: "categoricalCrossentropy",
    metrics: ["accuracy"],
  });

  // Early stopping with best-weights checkpoint
  let bestAcc = -1;
  let patience = 0;
  const bestWeights: Tensor[] = [];

  let finalLoss = 0;
  let finalAccuracy = 0;
  let lastEpoch = 0;

  const onEpochEnd = async (epoch: number, logs?: Record<string, number>) => {
    // TF.js reports the metric as "acc" or "accuracy" depending on version.
    const acc = logs?.["acc"] ?? logs?.["accuracy"] ?? 0;
    const loss = logs?.["loss"] ?? 0;

    lastEpoch = epoch + 1;
    finalLoss = loss;
    finalAccuracy = acc;

    if (acc > bestAcc) {
      bestAcc = acc;
      patience = 0;
      while (bestWeights.length) bestWeights.pop()!.dispose();
      model.getWeights().forEach((w) => bestWeights.push(w.clone()));
    } else {
      patience++;
      if (patience >= PATIENCE) model.stopTraining = true;
    }

    onProgress?.({
      epoch: epoch + 1,
      totalEpochs: EPOCHS,
      loss,
      accuracy: acc,
    });
  };

  try {
    await model.fit(xTensor, yTensor, {
      epochs: EPOCHS,
      batchSize: BATCH_SIZE,
      shuffle: true,
      callbacks: [new tf.CustomCallback({ onEpochEnd })],
    });

    // Restore best weights
    if (bestWeights.length > 0) {
      model.setWeights(bestWeights);
      while (bestWeights.length) bestWeights.pop()!.dispose();
    }

    xTensor.dispose();
    yTensor.dispose();

    const meta: TrainingMeta = {
      trainedAt: Date.now(),
      epochs: lastEpoch,
      finalLoss,
      finalAccuracy: bestAcc >= 0 ? bestAcc : finalAccuracy,
      vocabSize,
      numClasses,
      entriesUsed: samples.length,
      durationMs: Date.now() - t0,
    };

    await saveModelToStorage(model, vocab, presentLabels, meta);
    model.dispose();

    return { success: true, ...meta };
  } catch (err) {
    xTensor.dispose();
    yTensor.dispose();
    model.dispose();
    return {
      success: false,
      error: err instanceof Error ? err.message : String(err),
      trainedAt: Date.now(),
      epochs: lastEpoch,
      finalLoss,
      finalAccuracy,
      vocabSize,
      numClasses,
      entriesUsed: samples.length,
      durationMs: Date.now() - t0,
    };
  }
}
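
For reference, the shapes of the progress and result types can be inferred from the body above. This is a sketch reconstructed from usage, not the canonical declarations, which may carry extra fields:

// Inferred from the function body above; a sketch, not the source-of-truth declarations.
interface TrainingProgress {
  epoch: number;        // 1-based index of the epoch just completed
  totalEpochs: number;  // the EPOCHS constant
  loss: number;
  accuracy: number;
}

interface TrainingResult {
  success: boolean;
  error?: string;        // set only when success is false
  trainedAt: number;     // epoch milliseconds
  epochs: number;        // epochs actually run (early stopping may end training sooner)
  finalLoss: number;
  finalAccuracy: number; // best observed training accuracy on success
  vocabSize: number;
  numClasses: number;
  entriesUsed: number;
  durationMs: number;
}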

Dependencies (Outgoing)

graph LR
  trainModelFromDataset["trainModelFromDataset"]
  buildVocab["buildVocab"]
  vectorize["vectorize"]
  saveModelToStorage["saveModelToStorage"]
  trainModelFromDataset -->|calls| buildVocab
  trainModelFromDataset -->|calls| vectorize
  trainModelFromDataset -->|calls| saveModelToStorage
  style trainModelFromDataset fill:#dbeafe,stroke:#2563eb,stroke-width:2px
  click trainModelFromDataset "116a4fd1e25c7132.html"
  click buildVocab "775e3614ccc16f3a.html"
  click vectorize "20cf7613c5e8a682.html"
  click saveModelToStorage "62ae3258e166f5da.html"
Target               Type
buildVocab           calls
vectorize            calls
saveModelToStorage   calls

Impact (Incoming)

graph LR
  trainModelFromDataset["trainModelFromDataset"]
  loadModelStatus["loadModelStatus"]
  loadModelStatus -->|uses| trainModelFromDataset
  style trainModelFromDataset fill:#dbeafe,stroke:#2563eb,stroke-width:2px
  click trainModelFromDataset "116a4fd1e25c7132.html"
  click loadModelStatus "2c3f5e65d67a73f9.html"
Source            Type
loadModelStatus   uses