src/lib/form/detectors/strategies/tensorflow-classifier.ts

Total Symbols
11
Lines of Code
382
Avg Complexity
3.7
Avg Coverage
82.3%

File Relationships

graph LR loadPretrainedModel["loadPretrainedModel"] loadTfModule["loadTfModule"] loadLearnedVectors["loadLearnedVectors"] invalidateClassifier["invalidateClassifier"] reloadClassifier["reloadClassifier"] classifyField["classifyField"] classifyByTfSoft["classifyByTfSoft"] detect["detect"] loadPretrainedModel -->|calls| loadTfModule loadPretrainedModel -->|calls| loadLearnedVectors invalidateClassifier -->|calls| loadLearnedVectors reloadClassifier -->|calls| loadPretrainedModel classifyField -->|calls| classifyByTfSoft detect -->|calls| classifyByTfSoft click loadPretrainedModel "../symbols/5945d42bd468f616.html" click loadTfModule "../symbols/0aaf7f2a78a8a2ae.html" click loadLearnedVectors "../symbols/c050c9d2aa02d198.html" click invalidateClassifier "../symbols/a97a4f5efc9940ea.html" click reloadClassifier "../symbols/30b3749d6c005c84.html" click classifyField "../symbols/aa03a8b1140f5f42.html" click classifyByTfSoft "../symbols/0bd31cb0fc7321e5.html" click detect "../symbols/e524d9f6725c7557.html"

Architecture violations

View all

  • [warning] max-cyclomatic-complexity: 'classifyByTfSoft' has cyclomatic complexity 14 (max 10)

Symbols by Kind

function 8
interface 2
method 1

All Symbols

Name Kind Visibility Status Lines Signature
PretrainedState interface - 46-50 interface PretrainedState
LearnedVector interface - 52-55 interface LearnedVector
loadTfModule function - 67-75 loadTfModule(): Promise<typeof import("@tensorflow/tfjs")>
loadPretrainedModel function exported - 88-144 loadPretrainedModel(): Promise<void>
loadLearnedVectors function - 150-177 loadLearnedVectors(): Promise<void>
invalidateClassifier function exported - 183-194 invalidateClassifier(): void
disposeTensorflowModel function exported - 201-209 disposeTensorflowModel(): void
reloadClassifier function exported - 215-224 reloadClassifier(): Promise<void>
classifyByTfSoft function exported - 236-305 classifyByTfSoft( input: string | StructuredSignals, context?: StructuredSignalContext, ): { type: FieldType; score: number } | null
classifyField function exported - 314-365 classifyField(field: FormField): FieldType
detect method - 375-380 detect(field: FormField): ClassifierResult | null

Full Source

/**
 * TensorFlow.js Field Classifier — Detection Strategy
 *
 * Implements the FieldClassifier interface for use in the DetectionPipeline.
 * All classification logic lives here:
 *   - Pre-trained model loading (runtime-trained → bundled fallback)
 *   - Learned-vector lookup (Chrome AI + user corrections via learning-store)
 *   - TF.js softmax inference with cosine-similarity n-gram vectorisation
 *
 * Shared text utilities (charNgrams, vectorize, dotProduct) are imported from
 * src/lib/shared/ngram.ts so they can be independently unit-tested.
 *
 * Configuration, thresholds and log messages live in
 * tensorflow-classifier.config.ts — edit there to tune the classifier.
 *
 * DEBUG: Set `window.__FILL_ALL_DEBUG__ = true` in DevTools and trigger a fill
 * to see per-field classification details.
 */

import type { FieldType, FormField } from "@/types";
import type { LayersModel, Tensor } from "@tensorflow/tfjs";
import { getLearnedEntries } from "@/lib/ai/learning-store";
import { loadRuntimeModel } from "@/lib/ai/runtime-trainer";
import { dotProduct, vectorize } from "@/lib/shared/ngram";
import {
  buildFeatureText,
  fromFlatSignals,
  inferCategoryFromType,
  inferLanguageFromSignals,
  structuredSignalsFromField,
  type StructuredSignalContext,
  type StructuredSignals,
} from "@/lib/shared/structured-signals";
import type { FieldClassifier, ClassifierResult } from "../pipeline";
import { createLogger } from "@/lib/logger";
import { TF_CONFIG, TF_MESSAGES } from "./tensorflow-classifier.config";

export { TF_THRESHOLD } from "./tensorflow-classifier.config";

/** Module-scoped logger; all classifier diagnostics go through it. */
const log = createLogger("TFClassifier");

/** Score cut-offs (`learned` for learned-vector matches, `model` for softmax). */
const thresholds = TF_CONFIG.thresholds;

// ── Internal types ────────────────────────────────────────────────────────────

/**
 * Everything needed to run the pre-trained classifier: the model, the
 * vocabulary used to vectorise feature text, and the label for each output.
 */
interface PretrainedState {
  // TF.js layers model (runtime-trained or bundled, see loadPretrainedModel).
  model: LayersModel;
  // Vocabulary passed to `vectorize`; presumably n-gram → feature index — confirm in shared/ngram.
  vocab: Map<string, number>;
  // Field type for each model output index (looked up by the arg-max index).
  labels: FieldType[];
}

/** A learning-store entry vectorised against the current pre-trained vocab. */
interface LearnedVector {
  // Vector compared against the input vector with `dotProduct`.
  vector: Float32Array;
  // Field type returned when this vector is the best match above threshold.
  type: FieldType;
}

// ── Module state ──────────────────────────────────────────────────────────────

// Loaded model/vocab/labels; null until loadPretrainedModel succeeds, and reset by dispose/reload.
let _pretrained: PretrainedState | null = null;
// Load promise shared by concurrent loadPretrainedModel callers; stays set after it settles.
let _pretrainedLoadPromise: Promise<void> | null = null;
// Cached learned-store vectors, rebuilt by loadLearnedVectors and consulted first when classifying.
let _learnedVectors: LearnedVector[] = [];
// Lazily imported TF.js module; classifyByTfSoft returns null while this is unset.
let _tfModule: typeof import("@tensorflow/tfjs") | null = null;
// In-flight dynamic import of TF.js shared by concurrent loadTfModule callers.
let _tfLoadPromise: Promise<typeof import("@tensorflow/tfjs")> | null = null;

// ── TF.js lazy loader ─────────────────────────────────────────────────────────

/**
 * Lazily imports TF.js and caches the module in `_tfModule`.
 *
 * Concurrent callers share a single in-flight import. If the import fails,
 * the cached promise is cleared before the error propagates. A later call
 * can then retry instead of receiving the same rejection forever.
 *
 * @returns The TF.js module namespace.
 */
async function loadTfModule(): Promise<typeof import("@tensorflow/tfjs")> {
  if (_tfModule) return _tfModule;
  if (!_tfLoadPromise) {
    _tfLoadPromise = import("@tensorflow/tfjs").then(
      (mod) => {
        _tfModule = mod;
        return mod;
      },
      (err: unknown) => {
        // Do not cache the failure: let the next caller attempt the import again.
        _tfLoadPromise = null;
        throw err;
      },
    );
  }
  return _tfLoadPromise;
}

// ── Model loading ─────────────────────────────────────────────────────────────

/**
 * Loads the pre-trained TF.js model with its vocabulary and labels, then
 * vectorises the learning-store entries against that vocabulary.
 *
 * Priority:
 *   1. Runtime-trained model returned by `loadRuntimeModel()` (trained via the
 *      options page)
 *   2. Bundled model files resolved with `chrome.runtime.getURL`, using the
 *      paths in `TF_CONFIG.model` (ship-time default)
 *
 * Concurrent calls share one in-flight load. Once a model is loaded, further
 * calls are no-ops. Errors are logged, not thrown. A failed attempt clears
 * the cached promise, so a later call retries instead of silently returning
 * the settled failure.
 */
export async function loadPretrainedModel(): Promise<void> {
  if (_pretrained) return;
  if (_pretrainedLoadPromise) return _pretrainedLoadPromise;

  const loadPromise: Promise<void> = (async () => {
    try {
      const tf = await loadTfModule();

      // Step 1: Try runtime-trained model (user dataset, options page)
      const runtimeModel = await loadRuntimeModel();
      if (runtimeModel) {
        _pretrained = runtimeModel;
        await loadLearnedVectors();
        log.info(
          TF_MESSAGES.modelLoaded.runtime(
            runtimeModel.labels.length,
            runtimeModel.vocab.size,
            _learnedVectors.length,
          ),
        );
        return;
      }

      // Step 2: Fall back to bundled model files
      const bundled = await loadBundledModel(tf);
      _pretrained = bundled;
      await loadLearnedVectors();

      log.info(
        TF_MESSAGES.modelLoaded.bundled(
          bundled.labels.length,
          bundled.vocab.size,
          _learnedVectors.length,
        ),
      );
    } catch (err) {
      log.error(TF_MESSAGES.modelLoadFailed.error, err);
      log.warn(TF_MESSAGES.modelLoadFailed.fallback);
      // Allow a later call to retry. Only clear the cache if it still refers
      // to this attempt; reloadClassifier may have started a newer one.
      // (This handler runs after at least one `await`, so `loadPromise` is
      // already initialised here.)
      if (_pretrainedLoadPromise === loadPromise) {
        _pretrainedLoadPromise = null;
      }
    }
  })();

  _pretrainedLoadPromise = loadPromise;
  return loadPromise;
}

/**
 * Fetches a JSON extension asset, failing on non-2xx responses instead of
 * trying to parse an error body.
 */
async function fetchJsonAsset(url: string): Promise<unknown> {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${url}: HTTP ${response.status}`);
  }
  return response.json();
}

/** Loads the model, vocabulary and labels bundled with the extension. */
async function loadBundledModel(
  tf: typeof import("@tensorflow/tfjs"),
): Promise<PretrainedState> {
  const base = chrome.runtime.getURL("");
  const [model, vocabRaw, labelsRaw] = await Promise.all([
    tf.loadLayersModel(`${base}${TF_CONFIG.model.json}`),
    fetchJsonAsset(`${base}${TF_CONFIG.model.vocab}`) as Promise<
      Record<string, number>
    >,
    fetchJsonAsset(`${base}${TF_CONFIG.model.labels}`) as Promise<string[]>,
  ]);
  return {
    model,
    vocab: new Map(Object.entries(vocabRaw)),
    labels: labelsRaw as FieldType[],
  };
}

/**
 * Vectorises every learning-store entry against the loaded vocabulary.
 * Vectors with at least one positive component are cached in `_learnedVectors`.
 *
 * - No-op when no pre-trained model is loaded, because the vocab is required.
 * - Storage or vectorisation errors are logged and leave the cache empty.
 * - If the model is disposed or replaced while the store is being read, the
 *   result is discarded. loadPretrainedModel rebuilds the vectors itself
 *   after installing a model, so stale work is never written to the cache.
 */
async function loadLearnedVectors(): Promise<void> {
  // Snapshot the model state: `_pretrained` may change during the await below.
  const state = _pretrained;
  if (!state) return;
  try {
    const entries = await getLearnedEntries();
    if (_pretrained !== state) return;

    _learnedVectors = entries
      .map((entry) => {
        const featureText = buildFeatureText(fromFlatSignals(entry.signals), {
          category: inferCategoryFromType(entry.type),
          language: inferLanguageFromSignals(entry.signals),
        });

        return {
          vector: vectorize(featureText, state.vocab),
          type: entry.type,
        };
      })
      // All-zero vectors share no vocabulary with the model and can never match.
      .filter((e) => e.vector.some((v) => v > 0));
    log.debug(
      TF_MESSAGES.learnedVectors.summary(
        entries.length,
        _learnedVectors.length,
      ),
    );
  } catch (err) {
    log.warn(TF_MESSAGES.learnedVectors.failed, err);
    _learnedVectors = [];
  }
}

/**
 * Empties the learned-vectors cache immediately.
 *
 * If a pre-trained model is loaded, it then rebuilds the cache from the
 * learning-store in the background, without awaiting. Otherwise it only
 * logs a warning. Until the rebuild finishes, classification runs without
 * learned vectors.
 */
export function invalidateClassifier(): void {
  const droppedCount = _learnedVectors.length;
  _learnedVectors = [];
  log.debug(TF_MESSAGES.invalidate.dropped(droppedCount));

  if (!_pretrained) {
    log.warn(TF_MESSAGES.invalidate.notLoaded);
    return;
  }

  // Fire-and-forget rebuild; any rejection is logged rather than surfaced.
  void loadLearnedVectors().catch((err) => {
    log.error(TF_MESSAGES.invalidate.reloadError, err);
  });
}

/**
 * Releases the loaded TF.js model and the classifier state built on it.
 *
 * When a model is loaded, it disposes the model (freeing its GPU/WASM
 * memory), then clears the model state, the cached load promise and the
 * learned vectors. When no model is loaded, it does nothing. Intended for
 * when the classifier will no longer be used in this context (e.g. service
 * worker suspending or extension unloading).
 */
export function disposeTensorflowModel(): void {
  const state = _pretrained;
  if (!state) return;

  state.model.dispose();
  _pretrained = null;
  _pretrainedLoadPromise = null;
  _learnedVectors = [];
  // Message (Portuguese): "TF.js model and associated memory released."
  log.debug("Modelo TF.js e memória associada liberados.");
}

/**
 * Rebuilds the whole classifier: model, vocab and learned vectors.
 *
 * Disposes the current model (if any) and clears all cached state. It then
 * awaits loadPretrainedModel, which tries the runtime-trained model first
 * and the bundled files second. Call this after a new model has been trained
 * via the options page. The reload message is logged once the load attempt
 * has finished.
 */
export async function reloadClassifier(): Promise<void> {
  _pretrained?.model.dispose();

  _learnedVectors = [];
  _pretrainedLoadPromise = null;
  _pretrained = null;

  await loadPretrainedModel();
  log.info(TF_MESSAGES.reload);
}

// ── Core classification ───────────────────────────────────────────────────────

/**
 * Classifies a field's signals in two steps:
 *   1. Learned vectors (Chrome AI + user corrections). The best dot-product
 *      match is accepted when it reaches `thresholds.learned`.
 *   2. The TF.js pre-trained model. The arg-max softmax output is accepted
 *      when it reaches `thresholds.model`.
 *
 * @param input Flat signals string or already-structured signals.
 * @param context Optional context forwarded to `buildFeatureText`.
 * @returns The best field type and its score. Returns null if the signals are
 *   empty or share no vocabulary, the model/TF.js is not loaded, the best
 *   score is below threshold, or the predicted index has no label.
 */
export function classifyByTfSoft(
  input: string | StructuredSignals,
  context?: StructuredSignalContext,
): { type: FieldType; score: number } | null {
  const featureText =
    typeof input === "string"
      ? buildFeatureText(fromFlatSignals(input), context)
      : buildFeatureText(input, context);
  if (!featureText.trim()) return null;

  // Snapshot module state once so the code below needs no non-null assertions.
  const state = _pretrained;
  const tf = _tfModule;
  if (!state || !tf) {
    log.warn(TF_MESSAGES.classify.notLoaded(featureText));
    return null;
  }

  const inputVec = vectorize(featureText, state.vocab);
  if (!inputVec.some((v) => v > 0)) return null;

  // Step 1: Learned vectors (user corrections + Chrome AI)
  const learned = findBestLearnedMatch(inputVec);
  if (learned && learned.score >= thresholds.learned) {
    log.debug(
      TF_MESSAGES.classify.learnedMatch(
        learned.type,
        learned.score.toFixed(3),
        thresholds.learned,
        featureText,
      ),
    );
    return learned;
  }

  // Step 2: TF.js pre-trained model
  const { bestIdx, bestScore } = predictTopClass(tf, state.model, inputVec);
  if (bestScore < thresholds.model) {
    log.warn(
      TF_MESSAGES.classify.lowScore(
        bestScore.toFixed(3),
        thresholds.model,
        featureText,
        state.labels[bestIdx],
      ),
    );
    return null;
  }

  const label = state.labels[bestIdx];
  if (label === undefined) {
    // Model output is wider than the labels list (mismatched artifacts);
    // never hand callers a result whose `type` is undefined.
    log.warn(
      `Predicted index ${bestIdx} has no label (${state.labels.length} labels loaded).`,
    );
    return null;
  }
  return { type: label, score: bestScore };
}

/**
 * Returns the cached learned vector with the highest dot product against
 * `inputVec`, with that score. Returns null when no entry scores above -1,
 * which includes an empty cache.
 */
function findBestLearnedMatch(
  inputVec: Float32Array,
): { type: FieldType; score: number } | null {
  let bestType: FieldType | null = null;
  let bestScore = -1;
  for (const { vector, type } of _learnedVectors) {
    const similarity = dotProduct(inputVec, vector);
    if (similarity > bestScore) {
      bestScore = similarity;
      bestType = type;
    }
  }
  return bestType === null ? null : { type: bestType, score: bestScore };
}

/**
 * Runs the model on a single input vector (batch of one) and returns the
 * arg-max output index and its value. Intermediate tensors are freed by
 * `tf.tidy`.
 */
function predictTopClass(
  tf: typeof import("@tensorflow/tfjs"),
  model: LayersModel,
  inputVec: Float32Array,
): { bestIdx: number; bestScore: number } {
  return tf.tidy(() => {
    const batch = tf.tensor2d([Array.from(inputVec)]);
    const probs = (model.predict(batch) as Tensor).dataSync();
    let bestIdx = 0;
    let bestScore = -1;
    for (let i = 0; i < probs.length; i++) {
      if (probs[i] > bestScore) {
        bestScore = probs[i];
        bestIdx = i;
      }
    }
    return { bestIdx, bestScore };
  });
}

// ── classifyField (higher-level helper used by dataset/integration & generator) ──

/**
 * Classifies a FormField by running classifyByTfSoft on the structured
 * signals derived from the field.
 *
 * When that returns null, it falls back to `TF_CONFIG.htmlTypeFallback`,
 * keyed by the element's lower-cased `type` property, and finally to
 * "unknown". Either way, a collapsed log group records the feature text,
 * the decision and the field's label/name/id/placeholder.
 */
export function classifyField(field: FormField): FieldType {
  const { signals, context } = structuredSignalsFromField(field);
  const tfResult = classifyByTfSoft(signals, context);
  const featureText = buildFeatureText(signals, context);
  const shownFeatureText = featureText || "(none)";

  if (tfResult !== null) {
    const { type, score } = tfResult;
    const formattedScore = score.toFixed(3);
    log.groupCollapsed(
      TF_MESSAGES.classify.groupLabel(type, formattedScore, field.selector),
    );
    log.debug(TF_MESSAGES.classify.featureText, shownFeatureText);
    log.debug(
      TF_MESSAGES.classify.tfMatch(type, formattedScore, thresholds.model),
    );
    logFieldIdentity(field);
    log.groupEnd();
    return type;
  }

  const inputType = (field.element as HTMLInputElement).type?.toLowerCase();
  const fallbackKey = inputType as keyof typeof TF_CONFIG.htmlTypeFallback;
  const htmlType: FieldType =
    (TF_CONFIG.htmlTypeFallback[fallbackKey] as FieldType) ?? "unknown";

  log.groupCollapsed(
    TF_MESSAGES.classify.groupLabelFallback(htmlType, field.selector),
  );
  log.debug(TF_MESSAGES.classify.featureText, shownFeatureText);
  log.debug(TF_MESSAGES.classify.noMatch(inputType));
  logFieldIdentity(field);
  log.groupEnd();
  return htmlType;
}

/** Logs the field's label, name, id and placeholder inside the caller's open log group. */
function logFieldIdentity(field: FormField): void {
  log.debug(TF_MESSAGES.classify.field, {
    label: field.label,
    name: field.name,
    id: field.id,
    placeholder: field.placeholder,
  });
}

// ── FieldClassifier implementation ────────────────────────────────────────────

/**
 * DetectionPipeline strategy backed by the TF.js classifier.
 *
 * Derives structured signals from the field and delegates to
 * classifyByTfSoft. The classifier's `score` is reported as the pipeline
 * `confidence`. Returns null whenever classifyByTfSoft does.
 */
export const tensorflowClassifier: FieldClassifier = {
  name: "tensorflow",
  detect(field: FormField): ClassifierResult | null {
    const { signals, context } = structuredSignalsFromField(field);
    const match = classifyByTfSoft(signals, context);
    return match ? { type: match.type, confidence: match.score } : null;
  },
};