src/lib/dataset/integration.ts

Total Symbols
9
Lines of Code
229
Avg Complexity
2.0
Avg Coverage
100.0%

Symbols by Kind

function 7
interface 2

All Symbols

Name Kind Visibility Status Lines Signature
buildKeywordsFromDictionary function exported- 51-71 buildKeywordsFromDictionary(): : Record<FieldType, string[]>
AugmentationConfig interface exported- 75-82 interface AugmentationConfig
augmentTrainingSamples function exported- 94-127 augmentTrainingSamples( config: AugmentationConfig = DEFAULT_AUGMENTATION, ): : TrainingSample[]
classifySignals function - 132-145 classifySignals(signals: string): : FieldType
ClassifierReport interface exported- 148-153 interface ClassifierReport
validateClassifier function exported- 159-175 validateClassifier( thresholds: AccuracyThresholds = DEFAULT_THRESHOLDS, ): : ClassifierReport
testClassifier function exported- 180-196 testClassifier( thresholds: AccuracyThresholds = DEFAULT_THRESHOLDS, ): : ClassifierReport
getBalancingNeeds function exported- 204-219 getBalancingNeeds( targetCount = 10, ): : Record<string, { current: number; needed: number }>
getUncoveredDictionaryTypes function exported- 225-228 getUncoveredDictionaryTypes(): : FieldDictionaryEntry[]

Full Source

/**
 * Dataset ↔ Classifier Integration
 *
 * Bridges the dataset (training/validation/test) with the TF.js classifier
 * in `tensorflow-generator.ts`. Provides:
 *
 *   • `buildKeywordsFromDictionary`  — generates FIELD_TYPE_KEYWORDS from the
 *     field dictionary, keeping a single source of truth
 *   • `validateClassifier`           — runs the validation set and returns an
 *     accuracy report
 *   • `testClassifier`               — runs the test/control set (final eval)
 *   • `augmentTrainingSamples`       — produces additional samples through
 *     shuffle / drop / typo augmentation
 *   • `syncLearnedToDataset`         — imports learned entries from
 *     learning-store into the training pipeline
 */

import type { FieldType } from "@/types";
import {
  FIELD_DICTIONARY,
  type FieldDictionaryEntry,
} from "./field-dictionary";
import {
  TRAINING_SAMPLES,
  flattenStructuredSignals,
  getTrainingDistribution,
} from "./training-data";
import type { TrainingSample } from "@/types";
import { evaluateClassifier } from "./validation-data";
import { runTestEvaluation } from "./test-data";
import {
  DEFAULT_THRESHOLDS,
  augmentShuffle,
  augmentDrop,
  augmentTypo,
  checkDatasetHealth,
  type AccuracyThresholds,
  type DatasetHealthReport,
} from "./dataset-config";
import { classifyField } from "@/lib/form/detectors/strategies";
import type { FormField } from "@/types";
import { fromFlatSignals } from "@/lib/shared/structured-signals";

// ── Build keywords from dictionary ──────────────────────────────────────────

/**
 * Generates a `Record<FieldType, string[]>` keyword map directly from the
 * field dictionary. This can replace / supplement the static
 * FIELD_TYPE_KEYWORDS in tensorflow-generator.ts.
 */
export function buildKeywordsFromDictionary(): Record<FieldType, string[]> {
  const result = {} as Record<FieldType, string[]>;

  for (const entry of FIELD_DICTIONARY) {
    const keywords: string[] = [];

    // Derive keywords from tags (tags are the classification hints)
    if (entry.tags.length > 0) {
      keywords.push(...entry.tags);
    }

    // Derive keywords from entry type itself (e.g. "cpf" → ["cpf"])
    if (!keywords.includes(entry.type)) {
      keywords.push(entry.type);
    }

    result[entry.type] = [...new Set(keywords)]; // dedupe
  }

  return result;
}

// ── Augmentation ────────────────────────────────────────────────────────────

export interface AugmentationConfig {
  /** How many augmented samples to create per original sample */
  multiplier: number;
  /** Which augmentation strategies to apply */
  strategies: Array<"shuffle" | "drop" | "typo">;
  /** Drop rate for the "drop" strategy (default 0.2) */
  dropRate?: number;
}

const DEFAULT_AUGMENTATION: AugmentationConfig = {
  multiplier: 2,
  strategies: ["shuffle", "drop", "typo"],
  dropRate: 0.2,
};

/**
 * Generate augmented training samples by applying random transformations.
 * Returns only the new augmented samples (not the originals).
 */
export function augmentTrainingSamples(
  config: AugmentationConfig = DEFAULT_AUGMENTATION,
): TrainingSample[] {
  const augmented: TrainingSample[] = [];
  const strategyFns: Record<string, (s: string) => string> = {
    shuffle: augmentShuffle,
    drop: (s: string) => augmentDrop(s, config.dropRate ?? 0.2),
    typo: augmentTypo,
  };

  for (const sample of TRAINING_SAMPLES) {
    for (let i = 0; i < config.multiplier; i++) {
      // Pick a random strategy
      const strategy =
        config.strategies[Math.floor(Math.random() * config.strategies.length)];
      const fn = strategyFns[strategy];
      if (fn) {
        const signalText = flattenStructuredSignals(sample.signals);
        augmented.push({
          signals: fromFlatSignals(fn(signalText)),
          category: sample.category,
          type: sample.type,
          source: "augmented",
          domain: sample.domain,
          difficulty: sample.difficulty,
          language: sample.language,
          domFeatures: sample.domFeatures,
        });
      }
    }
  }

  return augmented;
}

// ── Validation & test runners ───────────────────────────────────────────────

/** Wraps classification for the validation/test evaluators */
function classifySignals(signals: string): FieldType {
  const mockField: FormField = {
    element: document.createElement("input"),
    selector: "",
    category: "unknown",
    fieldType: "unknown",
    label: signals,
    name: "",
    id: "",
    placeholder: "",
    required: false,
  };
  return classifyField(mockField);
}

/** Aggregated classifier accuracy report with health check. */
export interface ClassifierReport {
  globalAccuracy: number;
  passesGlobalThreshold: boolean;
  failingTypes: string[];
  health: DatasetHealthReport;
}

/**
 * Run the validation set against the current classifier.
 * Returns an actionable report.
 */
export function validateClassifier(
  thresholds: AccuracyThresholds = DEFAULT_THRESHOLDS,
): ClassifierReport {
  const result = evaluateClassifier(classifySignals);
  const health = checkDatasetHealth();

  const failingTypes = Object.entries(result.perType)
    .filter(([, v]) => v.accuracy < thresholds.perTypeMin)
    .map(([type]) => type);

  return {
    globalAccuracy: result.globalAccuracy,
    passesGlobalThreshold: result.globalAccuracy >= thresholds.globalMin,
    failingTypes,
    health,
  };
}

/**
 * Run the final test/control set — should only be called once after all tuning.
 */
export function testClassifier(
  thresholds: AccuracyThresholds = DEFAULT_THRESHOLDS,
): ClassifierReport {
  const result = runTestEvaluation(classifySignals);
  const health = checkDatasetHealth();

  const failingTypes = Object.entries(result.perType)
    .filter(([, v]) => v.accuracy < thresholds.perTypeMin)
    .map(([type]) => type);

  return {
    globalAccuracy: result.globalAccuracy,
    passesGlobalThreshold: result.globalAccuracy >= thresholds.globalMin,
    failingTypes,
    health,
  };
}

// ── Dataset balance helpers ─────────────────────────────────────────────────

/**
 * Returns the minimum number of extra samples needed per type to achieve
 * a balanced dataset (all types have `targetCount` samples).
 */
export function getBalancingNeeds(
  targetCount = 10,
): Record<string, { current: number; needed: number }> {
  const dist = getTrainingDistribution();
  const result: Record<string, { current: number; needed: number }> = {};

  for (const entry of FIELD_DICTIONARY) {
    const current = dist[entry.type] || 0;
    result[entry.type] = {
      current,
      needed: Math.max(0, targetCount - current),
    };
  }

  return result;
}

/**
 * Get a summary of dictionary entries that don't yet have matching
 * training samples. Useful for identifying gaps in dataset coverage.
 */
export function getUncoveredDictionaryTypes(): FieldDictionaryEntry[] {
  const dist = getTrainingDistribution();
  return FIELD_DICTIONARY.filter((e) => !(e.type in dist));
}