src/lib/ai/runtime-trainer.ts

Total Symbols
15
Lines of Code
456
Avg Complexity
3.1
Symbol Types
2

File Relationships

graph LR saveModelToStorage["saveModelToStorage"] arrayBufferToBase64["arrayBufferToBase64"] loadRuntimeModel["loadRuntimeModel"] base64ToArrayBuffer["base64ToArrayBuffer"] trainModelFromDataset["trainModelFromDataset"] buildVocab["buildVocab"] saveModelToStorage -->|calls| arrayBufferToBase64 loadRuntimeModel -->|calls| base64ToArrayBuffer trainModelFromDataset -->|calls| buildVocab trainModelFromDataset -->|calls| saveModelToStorage click saveModelToStorage "../symbols/62ae3258e166f5da.html" click arrayBufferToBase64 "../symbols/1f04fa87078e4837.html" click loadRuntimeModel "../symbols/83973a574fe66b54.html" click base64ToArrayBuffer "../symbols/7549b5b6cce3a6c9.html" click trainModelFromDataset "../symbols/116a4fd1e25c7132.html" click buildVocab "../symbols/775e3614ccc16f3a.html"

Architecture violations

View all

  • [warning] max-cyclomatic-complexity: 'trainModelFromDataset' has cyclomatic complexity 16 (max 10)
  • [warning] max-parameters: 'saveModelToStorage' has 6 parameters (max 5)
  • [warning] max-lines: 'trainModelFromDataset' has 192 lines (max 80)

Symbols by Kind

function 11
interface 4

All Symbols

Name Kind Visibility Status Lines Signature
TrainingProgress interface exported- 46-53 interface TrainingProgress
TrainingMeta interface exported- 55-64 interface TrainingMeta
TrainingResult interface exported- 66-69 interface TrainingResult
StoredModelArtifacts interface exported- 71-75 interface StoredModelArtifacts
charNgrams function - 91-105 charNgrams(text: string): : string[]
buildVocab function - 107-115 buildVocab(texts: string[]): : Map<string, number>
vectorize function - 117-128 vectorize(text: string, vocab: Map<string, number>): : Float32Array
arrayBufferToBase64 function - 130-137 arrayBufferToBase64(buffer: ArrayBuffer): : string
base64ToArrayBuffer function - 139-147 base64ToArrayBuffer(b64: string): : ArrayBuffer
saveModelToStorage function - 151-185 saveModelToStorage( model: LayersModel, vocab: Map<string, number>, labels: FieldType[], meta: TrainingMeta, ): : Promise<void>
loadRuntimeModel function exported- 191-228 loadRuntimeModel(): : Promise<{ model: LayersModel; vocab: Map<string, number>; labels: FieldType[]; } | null>
hasRuntimeModel function exported- 231-234 hasRuntimeModel(): : Promise<boolean>
getRuntimeModelMeta function exported- 237-240 getRuntimeModelMeta(): : Promise<TrainingMeta | null>
deleteRuntimeModel function exported- 243-250 deleteRuntimeModel(): : Promise<void>
trainModelFromDataset function exported- 264-455 trainModelFromDataset( entries: DatasetEntry[], onProgress?: (p: TrainingProgress) => void, ): : Promise<TrainingResult>

Full Source

/**
 * Runtime Model Trainer
 *
 * Trains a TF.js MLP classifier from user-curated dataset entries entirely
 * inside the browser (options page). The trained model is serialised and
 * stored in chrome.storage.local so it can be loaded by the content script
 * on any page.
 *
 * Architecture mirrors the offline train-model.ts script so the bundled
 * model and the runtime-trained model are interchangeable.
 *
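 *   n-gram vector → Dense(256, relu) → Dropout(0.3)
 *                 → Dense(128, relu) → Dropout(0.2)
 *                 → Dense(NUM_CLASSES, softmax)
 *
 *   Input: L2-normalised character-trigram counts of each entry's structured
 *   feature text. Both hidden Dense layers use L2 kernel regularisation
 *   (1e-4); NUM_CLASSES is the number of labels present in the dataset.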
 *
 * Storage layout (chrome.storage.local):
 *   fill_all_runtime_model   — { topology, weightSpecs, weightDataB64 }
 *   fill_all_runtime_vocab   — { [ngram]: index }
 *   fill_all_runtime_labels  — string[]   (FieldTypes present in the training
 *                               data; index i = softmax output unit i)
 *   fill_all_runtime_meta    — { trainedAt, epochs, finalLoss, finalAccuracy,
 *                               vocabSize, numClasses, entriesUsed, durationMs }
 */

import type { FieldType } from "@/types";
import { TRAINABLE_FIELD_TYPES } from "@/types";
import type { DatasetEntry } from "@/lib/dataset/runtime-dataset";
import type { LayersModel, Tensor } from "@tensorflow/tfjs";
import { createLogger } from "@/lib/logger";
import {
  buildFeatureText,
  fromFlatSignals,
  inferCategoryFromType,
  inferLanguageFromSignals,
} from "@/lib/shared/structured-signals";

/** Logger for this module, created with the name "RuntimeTrainer". */
const log = createLogger("RuntimeTrainer");

// ── Keys ────────────────────────────────────────────────────────────────────
// chrome.storage.local keys for the runtime model. saveModelToStorage writes
// all four in a single set() call and deleteRuntimeModel removes all four.

/** Serialised model artifacts (StoredModelArtifacts). */
export const RUNTIME_MODEL_KEY = "fill_all_runtime_model";
/** N-gram vocabulary stored as a plain object: n-gram → feature index. */
export const RUNTIME_VOCAB_KEY = "fill_all_runtime_vocab";
/** Ordered label list; position i corresponds to softmax output unit i. */
export const RUNTIME_LABELS_KEY = "fill_all_runtime_labels";
/** TrainingMeta describing the run that produced the stored model. */
export const RUNTIME_META_KEY = "fill_all_runtime_meta";

// ── Types ────────────────────────────────────────────────────────────────────

/**
 * Per-epoch progress report delivered to the optional `onProgress` callback
 * of trainModelFromDataset.
 */
export interface TrainingProgress {
  /** 1-based number of the epoch that just finished. */
  epoch: number;
  /** Configured maximum number of epochs; early stopping may end training sooner. */
  totalEpochs: number;
  /** Training loss logged by TF.js for this epoch. */
  loss: number;
  /** Training accuracy logged by TF.js for this epoch. */
  accuracy: number;
  /** Validation loss. Not set by trainModelFromDataset, which fits without validation data. */
  valLoss?: number;
  /** Validation accuracy. Not set by trainModelFromDataset, which fits without validation data. */
  valAccuracy?: number;
}

/**
 * Summary of a training run. Persisted under RUNTIME_META_KEY next to the
 * model and also embedded in every TrainingResult.
 */
export interface TrainingMeta {
  /** Timestamp of the result in milliseconds since the Unix epoch (Date.now()). */
  trainedAt: number;
  /** Number of epochs actually completed (0 if training never started). */
  epochs: number;
  /** Training loss reported for the run. */
  finalLoss: number;
  /** Training accuracy; on success, the best epoch's accuracy (those weights are kept). */
  finalAccuracy: number;
  /** Number of distinct character n-grams, i.e. the model's input width. */
  vocabSize: number;
  /** Number of distinct labels present in the training data, i.e. softmax outputs. */
  numClasses: number;
  /** Number of valid samples after filtering (0 when the dataset was rejected as too small). */
  entriesUsed: number;
  /** Wall-clock duration of the run in milliseconds. */
  durationMs: number;
}

/** Outcome of trainModelFromDataset: the run's metrics plus a success flag. */
export interface TrainingResult extends TrainingMeta {
  /** True when the model was trained and saved to storage. */
  success: boolean;
  /** Human-readable failure reason; only present when `success` is false. */
  error?: string;
}

/** JSON-serialisable form of the TF.js model artifacts kept in chrome.storage.local. */
export interface StoredModelArtifacts {
  /** Model topology as handed to the save handler by TF.js (`artifacts.modelTopology`). */
  topology: unknown;
  /** TF.js weight specs; cast to `io.WeightsManifestEntry[]` when loading. */
  weightSpecs: unknown[];
  /** Base64 encoding of the binary weight data. */
  weightDataB64: string;
}

// ── Constants ─────────────────────────────────────────────────────────────────

/** Width, in characters, of the n-grams used to featurise text. */
const NGRAM_SIZE = 3;
/** Upper bound on training epochs (early stopping may end training sooner). */
const EPOCHS = 80;
/** Mini-batch size passed to `model.fit`. */
const BATCH_SIZE = 32;
/** Consecutive epochs without a training-accuracy improvement before stopping. */
const PATIENCE = 20;

/** Every trainable FieldType label in canonical order (must match train-model.ts). */
const LABELS: FieldType[] = Array.from(TRAINABLE_FIELD_TYPES);

/** Fast membership test for LABELS. */
const LABEL_SET: ReadonlySet<string> = new Set<string>(LABELS);

// ── Helpers ──────────────────────────────────────────────────────────────────

/**
 * Splits `text` into overlapping character n-grams of length NGRAM_SIZE.
 *
 * Before splitting, the text is lower-cased, decomposed (NFD) and stripped of
 * combining diacritics; runs of `_ - / .` become spaces, whitespace is
 * collapsed and trimmed, and the result is framed with `_` on both ends so
 * word boundaries appear in the grams. Returns [] when the framed text is
 * shorter than NGRAM_SIZE.
 */
function charNgrams(text: string): string[] {
  const folded = text
    .toLowerCase()
    .normalize("NFD")
    .replace(/[\u0300-\u036f]/g, "");
  const spaced = folded
    .replace(/[_\-/.]+/g, " ")
    .replace(/\s+/g, " ")
    .trim();
  const framed = `_${spaced}_`;
  const gramCount = Math.max(0, framed.length - NGRAM_SIZE + 1);
  return Array.from({ length: gramCount }, (_unused, start) =>
    framed.slice(start, start + NGRAM_SIZE),
  );
}

/**
 * Assigns a dense integer index to every distinct n-gram found in `texts`,
 * numbering them 0, 1, 2, … in order of first appearance.
 */
function buildVocab(texts: string[]): Map<string, number> {
  const index = new Map<string, number>();
  texts.forEach((text) => {
    charNgrams(text).forEach((gram) => {
      if (!index.has(gram)) {
        index.set(gram, index.size);
      }
    });
  });
  return index;
}

/**
 * Encodes `text` as an L2-normalised bag of n-gram counts over `vocab`.
 *
 * N-grams absent from the vocabulary are ignored; if none of the text's
 * n-grams are known the all-zero vector is returned as-is.
 */
function vectorize(text: string, vocab: Map<string, number>): Float32Array {
  const counts = new Float32Array(vocab.size);
  charNgrams(text).forEach((gram) => {
    const slot = vocab.get(gram);
    if (slot !== undefined) counts[slot] += 1;
  });

  const magnitude = Math.sqrt(counts.reduce((sum, c) => sum + c * c, 0));
  if (magnitude > 0) {
    counts.forEach((c, i) => {
      counts[i] = c / magnitude;
    });
  }
  return counts;
}

/**
 * Encodes raw bytes as a base64 string (via `btoa`).
 *
 * The binary string is built in fixed-size chunks with one
 * `String.fromCharCode(...)` call per chunk instead of appending one character
 * per byte. Model weight buffers can be several megabytes, and per-byte
 * concatenation is needlessly slow there. The chunk size keeps the spread
 * argument count far below engine call-argument limits.
 *
 * @param buffer  Bytes to encode.
 * @returns Base64 text; "" for an empty buffer.
 */
function arrayBufferToBase64(buffer: ArrayBuffer): string {
  const chunkSize = 0x8000; // 32 KiB of bytes per fromCharCode call
  const bytes = new Uint8Array(buffer);
  const chunks: string[] = [];
  for (let offset = 0; offset < bytes.length; offset += chunkSize) {
    chunks.push(
      String.fromCharCode(...bytes.subarray(offset, offset + chunkSize)),
    );
  }
  return btoa(chunks.join(""));
}

/**
 * Decodes base64 text (via `atob`) into a freshly allocated ArrayBuffer whose
 * byte i is the char code of the i-th decoded character.
 */
function base64ToArrayBuffer(b64: string): ArrayBuffer {
  const decoded = atob(b64);
  const out = new ArrayBuffer(decoded.length);
  const bytes = new Uint8Array(out);
  let i = decoded.length;
  while (i-- > 0) {
    bytes[i] = decoded.charCodeAt(i);
  }
  return out;
}

// ── Model persistence ─────────────────────────────────────────────────────────

/**
 * Flattens TF.js weight data into a single ArrayBuffer.
 *
 * NOTE(review): recent TF.js releases type `ModelArtifacts.weightData` as
 * `ArrayBuffer | ArrayBuffer[]`; older ones always pass one buffer. Both
 * shapes are accepted so a list is never misread as a (zero-filled) byte
 * array-like.
 *
 * @throws Error when there is no weight data at all, so nothing is persisted.
 */
function toSingleArrayBuffer(weightData: unknown): ArrayBuffer {
  if (weightData == null) {
    throw new Error("Modelo treinado não possui dados de pesos para salvar.");
  }
  if (!Array.isArray(weightData)) return weightData as ArrayBuffer;

  const parts = weightData as ArrayBuffer[];
  const merged = new ArrayBuffer(
    parts.reduce((total, part) => total + part.byteLength, 0),
  );
  const view = new Uint8Array(merged);
  let offset = 0;
  for (const part of parts) {
    view.set(new Uint8Array(part), offset);
    offset += part.byteLength;
  }
  return merged;
}

/**
 * Serialises `model` through a custom TF.js IO handler and writes it, together
 * with the vocabulary, the label order and the training metadata, to
 * chrome.storage.local in a single `set()` call.
 *
 * @param model   Trained model to persist.
 * @param vocab   N-gram → feature-index map used to vectorise inputs.
 * @param labels  Labels in softmax output order.
 * @param meta    Training metadata stored next to the model.
 * @throws Error if the model yields no weight data; rejections from
 *   `model.save` / `chrome.storage.local.set` propagate unchanged.
 */
async function saveModelToStorage(
  model: LayersModel,
  vocab: Map<string, number>,
  labels: FieldType[],
  meta: TrainingMeta,
): Promise<void> {
  // Only a type-level reference to TF.js is needed here: a model instance
  // already exists, so the module is loaded.
  const storageHandler: import("@tensorflow/tfjs").io.IOHandler = {
    save: async (artifacts) => {
      // Validate/normalise before writing so a bad artifact never reaches storage.
      const weightData = toSingleArrayBuffer(artifacts.weightData);
      const stored: StoredModelArtifacts = {
        topology: artifacts.modelTopology,
        weightSpecs: (artifacts.weightSpecs ?? []) as unknown[],
        weightDataB64: arrayBufferToBase64(weightData),
      };
      await chrome.storage.local.set({
        [RUNTIME_MODEL_KEY]: stored,
        [RUNTIME_VOCAB_KEY]: Object.fromEntries(vocab),
        [RUNTIME_LABELS_KEY]: labels,
        [RUNTIME_META_KEY]: meta,
      });
      return {
        modelArtifactsInfo: {
          dateSaved: new Date(),
          modelTopologyType: "JSON" as const,
          weightDataBytes: weightData.byteLength,
        },
      };
    },
  };

  await model.save(storageHandler);
}

/**
 * Reads the runtime-trained model, its n-gram vocabulary and its ordered label
 * list from chrome.storage.local and rebuilds the TF.js model in memory.
 *
 * @returns The model bundle, or null when any of the three stored entries is
 *   missing (or falsy) or when TF.js fails to rebuild the model; that failure
 *   is logged. Rejections from chrome.storage itself are not caught.
 */
export async function loadRuntimeModel(): Promise<{
  model: LayersModel;
  vocab: Map<string, number>;
  labels: FieldType[];
} | null> {
  const stored = await chrome.storage.local.get([
    RUNTIME_MODEL_KEY,
    RUNTIME_VOCAB_KEY,
    RUNTIME_LABELS_KEY,
  ]);
  const artifacts = stored[RUNTIME_MODEL_KEY] as
    | StoredModelArtifacts
    | undefined;
  const vocabRecord = stored[RUNTIME_VOCAB_KEY] as
    | Record<string, number>
    | undefined;
  const labelList = stored[RUNTIME_LABELS_KEY] as string[] | undefined;

  // saveModelToStorage writes these together; without all three there is
  // nothing usable to load.
  if (!artifacts || !vocabRecord || !labelList) return null;

  try {
    const tf = await import("@tensorflow/tfjs");
    const fromStorage: import("@tensorflow/tfjs").io.IOHandler = {
      load: async () => ({
        format: "layers-model",
        modelTopology: artifacts.topology as object,
        weightSpecs:
          artifacts.weightSpecs as import("@tensorflow/tfjs").io.WeightsManifestEntry[],
        weightData: base64ToArrayBuffer(artifacts.weightDataB64),
      }),
    };
    const model = await tf.loadLayersModel(fromStorage);
    return {
      model,
      vocab: new Map<string, number>(Object.entries(vocabRecord)),
      labels: labelList as FieldType[],
    };
  } catch (err) {
    log.warn("Falha ao carregar modelo do storage:", err);
    return null;
  }
}

/**
 * Reports whether a runtime-trained model is present in chrome.storage.local.
 *
 * @returns true when the model entry exists and is truthy.
 */
export async function hasRuntimeModel(): Promise<boolean> {
  const { [RUNTIME_MODEL_KEY]: modelEntry } =
    await chrome.storage.local.get(RUNTIME_MODEL_KEY);
  return Boolean(modelEntry);
}

/**
 * Reads the training metadata stored alongside the runtime model.
 *
 * @returns The stored TrainingMeta, or null when none is stored.
 */
export async function getRuntimeModelMeta(): Promise<TrainingMeta | null> {
  const items = await chrome.storage.local.get(RUNTIME_META_KEY);
  const meta = items[RUNTIME_META_KEY] as TrainingMeta | undefined;
  return meta ?? null;
}

/**
 * Removes the runtime model and everything stored with it (vocabulary,
 * labels and training metadata) from chrome.storage.local.
 */
export async function deleteRuntimeModel(): Promise<void> {
  const keysToRemove = [
    RUNTIME_MODEL_KEY,
    RUNTIME_VOCAB_KEY,
    RUNTIME_LABELS_KEY,
    RUNTIME_META_KEY,
  ];
  await chrome.storage.local.remove(keysToRemove);
}

// ── Trainer ──────────────────────────────────────────────────────────────────

/** A dataset entry together with the structured feature text that gets vectorised. */
type TrainingSample = DatasetEntry & { featureText: string };

/**
 * Keeps entries whose `type` is a trainable label, derives their structured
 * feature text (category inferred from the type, language from the signals)
 * and drops entries whose feature text comes out empty.
 */
function toTrainingSamples(entries: DatasetEntry[]): TrainingSample[] {
  return entries
    .filter((entry) => LABEL_SET.has(entry.type))
    .map((entry) => {
      const structured = fromFlatSignals(entry.signals);
      const featureText = buildFeatureText(structured, {
        category: inferCategoryFromType(entry.type),
        language: inferLanguageFromSignals(entry.signals),
      });
      return { ...entry, featureText };
    })
    .filter((sample) => sample.featureText.length > 0);
}

/** Builds an unsuccessful TrainingResult; metrics not given in `metrics` are 0. */
function failedTrainingResult(
  error: string,
  metrics: Partial<TrainingMeta> = {},
): TrainingResult {
  return {
    success: false,
    error,
    trainedAt: Date.now(),
    epochs: 0,
    finalLoss: 0,
    finalAccuracy: 0,
    vocabSize: 0,
    numClasses: 0,
    entriesUsed: 0,
    durationMs: 0,
    ...metrics,
  };
}

/**
 * Creates and compiles the MLP classifier. The layer stack mirrors the offline
 * train-model.ts script so bundled and runtime models are interchangeable.
 */
function buildClassifier(
  tf: typeof import("@tensorflow/tfjs"),
  inputSize: number,
  numClasses: number,
) {
  const model = tf.sequential({
    layers: [
      tf.layers.dense({
        inputShape: [inputSize],
        units: 256,
        activation: "relu",
        kernelRegularizer: tf.regularizers.l2({ l2: 1e-4 }),
      }),
      tf.layers.dropout({ rate: 0.3 }),
      tf.layers.dense({
        units: 128,
        activation: "relu",
        kernelRegularizer: tf.regularizers.l2({ l2: 1e-4 }),
      }),
      tf.layers.dropout({ rate: 0.2 }),
      tf.layers.dense({ units: numClasses, activation: "softmax" }),
    ],
  });
  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: "categoricalCrossentropy",
    metrics: ["accuracy"],
  });
  return model;
}

/** Disposes every tensor in `tensors` and empties the array. */
function disposeAll(tensors: Tensor[]): void {
  for (const t of tensors.splice(0)) t.dispose();
}

/**
 * Trains a MLP classifier from the provided dataset entries, stores the
 * resulting model in chrome.storage.local and returns training metrics.
 *
 * Only entries whose type is trainable and whose feature text is non-empty are
 * used. At least 10 such samples covering at least 2 labels are required;
 * otherwise a `success: false` result is returned without training.
 *
 * Training runs for up to EPOCHS epochs with early stopping on *training*
 * accuracy (no validation split): after PATIENCE epochs without improvement it
 * stops, and the best epoch's weights are restored before saving. The reported
 * finalLoss/finalAccuracy describe that best epoch.
 *
 * Must be called from an extension page (options/popup) — NOT from a content
 * script, because TF.js WebGL backend requires a live DOM context.
 *
 * @param entries  Dataset entries to train on.
 * @param onProgress  Optional callback invoked after every epoch.
 * @returns Metrics of the run; `success` is false (with `error`) when the
 *   dataset is insufficient or when fitting/saving the model throws. All
 *   tensors allocated for training are disposed on every path.
 */
export async function trainModelFromDataset(
  entries: DatasetEntry[],
  onProgress?: (p: TrainingProgress) => void,
): Promise<TrainingResult> {
  const t0 = Date.now();
  const samples = toTrainingSamples(entries);

  if (samples.length < 10) {
    return failedTrainingResult(
      `Dataset muito pequeno (${samples.length} amostras válidas). Mínimo: 10.`,
    );
  }

  // Labels actually present, kept in canonical LABELS order; a label's position
  // is its softmax output unit and the order persisted with the model.
  const presentTypes = new Set<string>(samples.map((s) => s.type));
  const presentLabels = LABELS.filter((label) => presentTypes.has(label));
  const labelToIdx = new Map<string, number>(
    presentLabels.map((label, idx): [string, number] => [label, idx]),
  );
  const numClasses = presentLabels.length;

  if (numClasses < 2) {
    return failedTrainingResult(
      `O dataset precisa ter pelo menos 2 tipos de campo diferentes para treinar. Encontrado: ${numClasses} tipo (${presentLabels[0] ?? "nenhum"}). Adicione amostras de outros tipos ou importe o dataset padrão.`,
      { numClasses, entriesUsed: samples.length },
    );
  }

  const vocab = buildVocab(samples.map((s) => s.featureText));
  const vocabSize = vocab.size;
  const numSamples = samples.length;

  // Row-major feature / one-hot label matrices as flat Float32Arrays
  // (tensor2d takes a flat typed array plus an explicit shape) rather than one
  // boxed number[] per row.
  const xs = new Float32Array(numSamples * vocabSize);
  const ys = new Float32Array(numSamples * numClasses);
  samples.forEach((sample, row) => {
    xs.set(vectorize(sample.featureText, vocab), row * vocabSize);
    const classIdx = labelToIdx.get(sample.type);
    if (classIdx !== undefined) ys[row * numClasses + classIdx] = 1;
  });

  const tf = await import("@tensorflow/tfjs");
  await tf.ready();

  const model = buildClassifier(tf, vocabSize, numClasses);

  // Early stopping with a best-weights checkpoint (monitors training accuracy).
  const bestWeights: Tensor[] = [];
  let bestAcc = -1;
  let bestLoss = 0;
  let patience = 0;
  let finalLoss = 0;
  let finalAccuracy = 0;
  let lastEpoch = 0;

  const onEpochEnd = async (epoch: number, logs?: Record<string, number>) => {
    // The accuracy metric may be logged as "acc" or "accuracy".
    const acc = logs?.["acc"] ?? logs?.["accuracy"] ?? 0;
    const loss = logs?.["loss"] ?? 0;

    lastEpoch = epoch + 1;
    finalLoss = loss;
    finalAccuracy = acc;

    if (acc > bestAcc) {
      bestAcc = acc;
      bestLoss = loss;
      patience = 0;
      disposeAll(bestWeights);
      for (const w of model.getWeights()) bestWeights.push(w.clone());
    } else {
      patience++;
      if (patience >= PATIENCE) model.stopTraining = true;
    }

    onProgress?.({
      epoch: epoch + 1,
      totalEpochs: EPOCHS,
      loss,
      accuracy: acc,
    });
  };

  let xTensor: Tensor | undefined;
  let yTensor: Tensor | undefined;
  try {
    xTensor = tf.tensor2d(xs, [numSamples, vocabSize]);
    yTensor = tf.tensor2d(ys, [numSamples, numClasses]);

    await model.fit(xTensor, yTensor, {
      epochs: EPOCHS,
      batchSize: BATCH_SIZE,
      shuffle: true,
      callbacks: [new tf.CustomCallback({ onEpochEnd })],
    });

    const improved = bestAcc >= 0;
    if (bestWeights.length > 0) {
      // Restore the checkpoint, then release the cloned copies.
      model.setWeights(bestWeights);
      disposeAll(bestWeights);
    }

    const meta: TrainingMeta = {
      trainedAt: Date.now(),
      epochs: lastEpoch,
      // Report the metrics of the epoch whose weights are actually saved.
      finalLoss: improved ? bestLoss : finalLoss,
      finalAccuracy: improved ? bestAcc : finalAccuracy,
      vocabSize,
      numClasses,
      entriesUsed: numSamples,
      durationMs: Date.now() - t0,
    };

    await saveModelToStorage(model, vocab, presentLabels, meta);
    return { success: true, ...meta };
  } catch (err) {
    return failedTrainingResult(
      err instanceof Error ? err.message : String(err),
      {
        epochs: lastEpoch,
        finalLoss,
        finalAccuracy,
        vocabSize,
        numClasses,
        entriesUsed: numSamples,
        durationMs: Date.now() - t0,
      },
    );
  } finally {
    // Runs on success and failure alike: the checkpoint clones previously
    // leaked whenever fit() or the save threw.
    xTensor?.dispose();
    yTensor?.dispose();
    disposeAll(bestWeights);
    model.dispose();
  }
}