diff --git a/.rubocop.yml b/.rubocop.yml index dfadfcd..fb468d4 100644 --- a/.rubocop.yml +++ b/.rubocop.yml @@ -29,16 +29,18 @@ Metrics/BlockLength: Exclude: - 'test/**/*' - 'lib/classifier/extensions/vector.rb' + - 'lib/classifier/cli.rb' -# Allow longer methods in complex algorithms (SVD, etc.) +# Allow longer methods in complex algorithms (SVD, etc.) and CLI Metrics/MethodLength: Max: 25 Exclude: - 'test/**/*' - 'lib/classifier/extensions/vector.rb' - 'lib/classifier/lsi/content_node.rb' + - 'lib/classifier/cli.rb' -# Allow higher complexity for mathematical algorithms +# Allow higher complexity for mathematical algorithms and CLI Metrics/AbcSize: Max: 30 Exclude: @@ -46,25 +48,29 @@ Metrics/AbcSize: - 'lib/classifier/extensions/vector.rb' - 'lib/classifier/lsi.rb' - 'lib/classifier/lsi/content_node.rb' + - 'lib/classifier/cli.rb' Metrics/CyclomaticComplexity: Max: 10 Exclude: - 'lib/classifier/extensions/vector.rb' - 'lib/classifier/lsi/content_node.rb' + - 'lib/classifier/cli.rb' Metrics/PerceivedComplexity: Max: 10 Exclude: - 'lib/classifier/extensions/vector.rb' - 'lib/classifier/lsi/content_node.rb' + - 'lib/classifier/cli.rb' -# Class length limits - algorithms and tests can be longer +# Class length limits - algorithms, tests and CLI can be longer Metrics/ClassLength: Max: 250 Exclude: - 'test/**/*' - 'lib/classifier/lsi.rb' + - 'lib/classifier/cli.rb' # SV_decomp is a standard algorithm name Naming/MethodName: diff --git a/Gemfile.lock b/Gemfile.lock index 6d5506d..9df8beb 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -1,7 +1,7 @@ PATH remote: . specs: - classifier (2.2.0) + classifier (2.3.0) fast-stemmer (~> 1.0) matrix mutex_m (~> 0.2) @@ -23,12 +23,17 @@ GEM securerandom (>= 0.3) tzinfo (~> 2.0, >= 2.0.5) uri (>= 0.13.1) + addressable (2.8.8) + public_suffix (>= 2.0.2, < 8.0) ast (2.4.3) base64 (0.3.0) bigdecimal (4.0.1) cgi (0.5.1) concurrent-ruby (1.3.6) connection_pool (3.0.2) + crack (1.0.1) + bigdecimal + rexml csv (3.3.5) date (3.5.1) docile (1.4.1) @@ -39,6 +44,7 @@ GEM ffi (1.17.2-arm64-darwin) ffi (1.17.2-x86_64-linux-gnu) fileutils (1.8.0) + hashdiff (1.2.1) i18n (1.14.8) concurrent-ruby (~> 1.0) json (2.18.0) @@ -61,6 +67,7 @@ GEM psych (5.3.1) date stringio + public_suffix (7.0.0) racc (1.8.1) rainbow (3.1.1) rake (13.3.1) @@ -80,6 +87,7 @@ GEM psych (>= 4.0.0) tsort regexp_parser (2.11.3) + rexml (3.4.4) rubocop (1.82.1) json (~> 2.3) language_server-protocol (~> 3.17.0.2) @@ -134,6 +142,10 @@ GEM unicode-emoji (~> 4.1) unicode-emoji (4.2.0) uri (1.1.1) + webmock (3.26.1) + addressable (>= 2.8.0) + crack (>= 0.3.2) + hashdiff (>= 0.4.0, < 2.0.0) PLATFORMS arm64-darwin-22 @@ -156,6 +168,7 @@ DEPENDENCIES rubocop-minitest simplecov steep + webmock BUNDLED WITH 4.0.3 diff --git a/Steepfile b/Steepfile index 1a72a0a..d5120ce 100644 --- a/Steepfile +++ b/Steepfile @@ -5,6 +5,12 @@ target :lib do check 'lib' + # Stdlib dependencies for CLI + library 'fileutils' + library 'uri' + library 'net-http' + library 'json' + # Strict mode: report methods without type annotations configure_code_diagnostics(D::Ruby.strict) diff --git a/classifier.gemspec b/classifier.gemspec index fe9ea07..9ebfbbe 100644 --- a/classifier.gemspec +++ b/classifier.gemspec @@ -1,6 +1,8 @@ +require_relative 'lib/classifier/version' + Gem::Specification.new do |s| s.name = 'classifier' - s.version = '2.2.0' + s.version = Classifier::VERSION s.summary = 'Text classification with Bayesian, LSI, Logistic Regression, kNN, and TF-IDF vectorization.' s.description = 'A Ruby library for text classification featuring Naive Bayes, LSI (Latent Semantic Indexing), ' \ 'Logistic Regression, and k-Nearest Neighbors classifiers. Includes TF-IDF vectorization, ' \ @@ -16,7 +18,9 @@ Gem::Specification.new do |s| 'changelog_uri' => 'https://github.com/cardmagic/classifier/releases' } s.required_ruby_version = '>= 3.1' - s.files = Dir['{lib,sig}/**/*.{rb,rbs}', 'ext/**/*.{c,h,rb}', 'bin/*', 'LICENSE', '*.md', 'test/*'] + s.files = Dir['{lib,sig,exe}/**/*.{rb,rbs}', 'ext/**/*.{c,h,rb}', 'exe/*', 'bin/*', 'LICENSE', '*.md', 'test/*'] + s.bindir = 'exe' + s.executables = ['classifier'] s.extensions = ['ext/classifier/extconf.rb'] s.license = 'LGPL' @@ -28,4 +32,5 @@ Gem::Specification.new do |s| s.add_development_dependency 'rbs-inline' s.add_development_dependency 'rdoc' s.add_development_dependency 'rake-compiler' + s.add_development_dependency 'webmock' end diff --git a/exe/classifier b/exe/classifier new file mode 100755 index 0000000..cf73b86 --- /dev/null +++ b/exe/classifier @@ -0,0 +1,9 @@ +#!/usr/bin/env ruby +require 'classifier/cli' + +result = Classifier::CLI.new(ARGV).run + +warn result[:error] unless result[:error].empty? +puts result[:output] unless result[:output].empty? + +exit result[:exit_code] diff --git a/lib/classifier.rb b/lib/classifier.rb index 77cbd93..369dc64 100644 --- a/lib/classifier.rb +++ b/lib/classifier.rb @@ -25,6 +25,7 @@ # License:: LGPL require 'rubygems' +require 'classifier/version' require 'classifier/errors' require 'classifier/storage' require 'classifier/streaming' diff --git a/lib/classifier/cli.rb b/lib/classifier/cli.rb new file mode 100644 index 0000000..33d792c --- /dev/null +++ b/lib/classifier/cli.rb @@ -0,0 +1,880 @@ +# rbs_inline: enabled + +require 'json' +require 'optparse' +require 'net/http' +require 'uri' +require 'fileutils' +require 'classifier' + +module Classifier + class CLI + # @rbs @args: Array[String] + # @rbs @stdin: String? + # @rbs @options: Hash[Symbol, untyped] + # @rbs @output: Array[String] + # @rbs @error: Array[String] + # @rbs @exit_code: Integer + # @rbs @parser: OptionParser + + CLASSIFIER_TYPES = { + 'bayes' => :bayes, + 'lsi' => :lsi, + 'knn' => :knn, + 'lr' => :logistic_regression, + 'logistic_regression' => :logistic_regression + }.freeze + + DEFAULT_REGISTRY = ENV.fetch('CLASSIFIER_REGISTRY', 'cardmagic/classifier-models') #: String + CACHE_DIR = ENV.fetch('CLASSIFIER_CACHE', File.expand_path('~/.classifier')) #: String + + def initialize(args, stdin: nil) + @args = args.dup + @stdin = stdin + @options = { + model: ENV.fetch('CLASSIFIER_MODEL', './classifier.json'), + type: ENV.fetch('CLASSIFIER_TYPE', 'bayes'), + probabilities: false, + quiet: false, + count: 10, + k: 5, + weighted: false, + learning_rate: nil, + regularization: nil, + max_iterations: nil, + remote: nil, + output_path: nil + } + @output = [] #: Array[String] + @error = [] #: Array[String] + @exit_code = 0 + end + + def run + parse_options + execute_command + { output: @output.join("\n"), error: @error.join("\n"), exit_code: @exit_code } + rescue OptionParser::InvalidOption, OptionParser::MissingArgument, OptionParser::InvalidArgument => e + @error << "Error: #{e.message}" + @exit_code = 2 + { output: @output.join("\n"), error: @error.join("\n"), exit_code: @exit_code } + rescue StandardError => e + @error << "Error: #{e.message}" + @exit_code = 1 + { output: @output.join("\n"), error: @error.join("\n"), exit_code: @exit_code } + end + + private + + def parse_options + @parser = OptionParser.new do |opts| + opts.banner = 'Usage: classifier [options] [command] [arguments]' + opts.separator '' + opts.separator 'Commands:' + opts.separator ' train [files...] Train a category from files or stdin' + opts.separator ' info Show model information' + opts.separator ' fit Fit the model (logistic regression)' + opts.separator ' search Semantic search (LSI only)' + opts.separator ' related Find related documents (LSI only)' + opts.separator ' models [registry] List models in registry' + opts.separator ' pull Download model from registry' + opts.separator ' push Contribute model to registry' + opts.separator ' Classify text (default action)' + opts.separator '' + opts.separator 'Options:' + + opts.on('-f', '--file FILE', 'Model file (default: ./classifier.json)') do |file| + @options[:model] = file + end + + opts.on('-m', '--model TYPE', 'Classifier model: bayes, lsi, knn, lr (default: bayes)') do |type| + unless CLASSIFIER_TYPES.key?(type) + raise OptionParser::InvalidArgument, "Unknown classifier model: #{type}. Valid models: #{CLASSIFIER_TYPES.keys.join(', ')}" + end + + @options[:type] = type + end + + opts.on('-r', '--remote MODEL', 'Use remote model: name or @user/repo:name') do |model| + @options[:remote] = model + end + + opts.on('-o', '--output FILE', 'Output path for pull command') do |file| + @options[:output_path] = file + end + + opts.on('-p', 'Show probabilities') do + @options[:probabilities] = true + end + + opts.on('-n', '--count N', Integer, 'Number of results for search/related (default: 10)') do |n| + @options[:count] = n + end + + opts.on('-k', '--neighbors N', Integer, 'Number of neighbors for KNN (default: 5)') do |n| + @options[:k] = n + end + + opts.on('--weighted', 'Use distance-weighted voting for KNN') do + @options[:weighted] = true + end + + opts.on('--learning-rate N', Float, 'Learning rate for logistic regression (default: 0.1)') do |n| + @options[:learning_rate] = n + end + + opts.on('--regularization N', Float, 'L2 regularization for logistic regression (default: 0.01)') do |n| + @options[:regularization] = n + end + + opts.on('--max-iterations N', Integer, 'Max iterations for logistic regression (default: 100)') do |n| + @options[:max_iterations] = n + end + + opts.on('-q', 'Quiet mode') do + @options[:quiet] = true + end + + opts.on('--local', 'List locally cached models (for models command)') do + @options[:local] = true + end + + opts.on('-v', '--version', 'Show version') do + @output << Classifier::VERSION + @exit_code = 0 + throw :done + end + + opts.on('-h', '--help', 'Show help') do + @output << opts.to_s + @exit_code = 0 + throw :done + end + end + + catch(:done) do + @parser.parse!(@args) + end + end + + def execute_command + return if @exit_code != 0 || @output.any? + + command = @args.first + + case command + when 'train' + command_train + when 'info' + command_info + when 'fit' + command_fit + when 'search' + command_search + when 'related' + command_related + when 'models' + command_models + when 'pull' + command_pull + when 'push' + command_push + else + command_classify + end + end + + def command_train + @args.shift # remove 'train' + category = @args.shift + + unless category + @error << 'Error: category required for train command' + @exit_code = 2 + return + end + + classifier = load_or_create_classifier + + if classifier.is_a?(LSI) && @args.any? + train_lsi_from_files(classifier, category, @args) + save_classifier(classifier) + return + end + + text = read_training_input + if text.empty? + @error << 'Error: no training data provided' + @exit_code = 2 + return + end + + train_classifier(classifier, category, text) + save_classifier(classifier) + end + + def command_info + unless File.exist?(@options[:model]) + @error << "Error: model not found at #{@options[:model]}" + @exit_code = 1 + return + end + + classifier = load_classifier + info = build_model_info(classifier) + @output << JSON.pretty_generate(info) + end + + def build_model_info(classifier) + info = { file: @options[:model], type: classifier_type_name(classifier) } + add_common_info(info, classifier) + add_classifier_specific_info(info, classifier) + info + end + + def add_common_info(info, classifier) + info[:categories] = classifier.categories.map(&:to_s) if classifier.respond_to?(:categories) + info[:training_count] = classifier.training_count if classifier.respond_to?(:training_count) + info[:vocab_size] = classifier.vocab_size if classifier.respond_to?(:vocab_size) + info[:fitted] = classifier.fitted? if classifier.respond_to?(:fitted?) + end + + def add_classifier_specific_info(info, classifier) + case classifier + when Bayes then add_bayes_info(info, classifier) + when LSI then add_lsi_info(info, classifier) + when KNN then add_knn_info(info, classifier) + end + end + + def add_bayes_info(info, classifier) + categories_data = classifier.instance_variable_get(:@categories) + info[:category_stats] = classifier.categories.to_h do |cat| + cat_data = categories_data[cat.to_sym] || {} + [cat.to_s, { unique_words: cat_data.size, total_words: cat_data.values.sum }] + end + end + + def add_lsi_info(info, classifier) + info[:documents] = classifier.items.size + info[:items] = classifier.items + categories = classifier.items.map { |item| classifier.categories_for(item) }.flatten.uniq + info[:categories] = categories.map(&:to_s) unless categories.empty? + end + + def add_knn_info(info, classifier) + data = classifier.instance_variable_get(:@data) || [] + info[:documents] = data.size + categories = data.map { |d| d[:category] }.uniq + info[:categories] = categories.map(&:to_s) unless categories.empty? + end + + def command_fit + unless File.exist?(@options[:model]) + @error << "Error: model not found at #{@options[:model]}" + @exit_code = 1 + return + end + + classifier = load_classifier + + unless classifier.respond_to?(:fit) + @output << 'Model does not require fitting' unless @options[:quiet] + return + end + + classifier.fit + save_classifier(classifier) + @output << 'Model fitted successfully' unless @options[:quiet] + end + + def command_search + @args.shift # remove 'search' + + unless File.exist?(@options[:model]) + @error << "Error: model not found at #{@options[:model]}" + @exit_code = 1 + return + end + + classifier = load_classifier + + unless classifier.is_a?(LSI) + @error << 'Error: search requires LSI model (use -t lsi)' + @exit_code = 1 + return + end + + query = @args.join(' ') + query = read_stdin_line if query.empty? + + if query.empty? + @error << 'Error: search query required' + @exit_code = 2 + return + end + + results = classifier.search(query, @options[:count]) + results.each do |item| + score = classifier.proximity_norms_for_content(query).find { |i, _| i == item }&.last || 0 + @output << "#{item}:#{format('%.2f', score)}" + end + end + + def command_related + @args.shift # remove 'related' + item = @args.shift + + unless item + @error << 'Error: item required for related command' + @exit_code = 2 + return + end + + unless File.exist?(@options[:model]) + @error << "Error: model not found at #{@options[:model]}" + @exit_code = 1 + return + end + + classifier = load_classifier + + unless classifier.is_a?(LSI) + @error << 'Error: related requires LSI model (use -t lsi)' + @exit_code = 1 + return + end + + unless classifier.items.include?(item) + @error << "Error: item not found in model: #{item}" + @exit_code = 1 + return + end + + results = classifier.find_related(item, @options[:count]) + results.each do |related_item| + scores = classifier.proximity_array_for_content(item) + score = scores.find { |i, _| i == related_item }&.last || 0 + @output << "#{related_item}:#{format('%.2f', score)}" + end + end + + def command_models + @args.shift # remove 'models' + + if @options[:local] + list_local_models + else + list_remote_models + end + end + + def list_remote_models + registry_arg = @args.shift + registry = parse_registry(registry_arg) || DEFAULT_REGISTRY + index = fetch_registry_index(registry) + + return if @exit_code != 0 + + if index['models'].empty? + @output << 'No models found in registry' + return + end + + index['models'].each do |name, info| + type = info['type'] || 'unknown' + size = info['size'] || 'unknown' + desc = info['description'] || '' + @output << format('%-20s %s (%s, %s)', name: name, desc: desc.slice(0, 40), type: type, size: size) + end + end + + def list_local_models + models_dir = File.join(CACHE_DIR, 'models') + + unless Dir.exist?(models_dir) + @output << 'No local models found' + return + end + + # Find models from default registry + default_models = Dir.glob(File.join(models_dir, '*.json')).map do |path| + { name: File.basename(path, '.json'), registry: nil, path: path } + end + + # Find models from custom registries (@user/repo structure) + custom_models = Dir.glob(File.join(models_dir, '@*', '*', '*.json')).map do |path| + # Extract registry from path: .../models/@user/repo/model.json + repo_dir = File.dirname(path) + user_dir = File.dirname(repo_dir) + registry = "#{File.basename(user_dir).delete_prefix('@')}/#{File.basename(repo_dir)}" + { name: File.basename(path, '.json'), registry: registry, path: path } + end + + models = default_models + custom_models #: Array[{name: String, registry: String?, path: String}] + + if models.empty? + @output << 'No local models found' + return + end + + models.each do |model| + info = load_model_info(model[:path]) + type = info['type'] || 'unknown' + display_name = model[:registry] ? "@#{model[:registry]}:#{model[:name]}" : model[:name] + size = File.size(model[:path]) + @output << format('%-30s (%s, %s)', name: display_name, type: type, size: human_size(size)) + end + end + + def load_model_info(path) + JSON.parse(File.read(path)) + rescue JSON::ParserError + {} + end + + def human_size(bytes) + units = %w[B KB MB GB] + unit_index = 0 + size = bytes.to_f + + while size >= 1024 && unit_index < units.length - 1 + size /= 1024 + unit_index += 1 + end + + format('%.1f %s', size: size, unit: units[unit_index]) + end + + def command_pull + @args.shift # remove 'pull' + model_spec = @args.shift + + unless model_spec + @error << 'Error: model name required for pull command' + @exit_code = 2 + return + end + + registry, model_name = parse_model_spec(model_spec) + registry ||= DEFAULT_REGISTRY + + if model_name.nil? + pull_all_models(registry) + else + pull_single_model(registry, model_name) + end + end + + def pull_single_model(registry, model_name) + index = fetch_registry_index(registry) + return if @exit_code != 0 + + model_info = index['models'][model_name] + unless model_info + @error << "Error: model '#{model_name}' not found in registry #{registry}" + @exit_code = 1 + return + end + + file_path = model_info['file'] || "models/#{model_name}.json" + output_path = @options[:output_path] || cache_path_for(registry, model_name) + + @output << "Downloading #{model_name} from #{registry}..." unless @options[:quiet] + + content = fetch_github_file(registry, file_path) + return if @exit_code != 0 + + FileUtils.mkdir_p(File.dirname(output_path)) + File.write(output_path, content) + + @output << "Saved to #{output_path}" unless @options[:quiet] + end + + def pull_all_models(registry) + index = fetch_registry_index(registry) + return if @exit_code != 0 + + if index['models'].empty? + @output << 'No models found in registry' + return + end + + @output << "Downloading #{index['models'].size} models from #{registry}..." unless @options[:quiet] + + index['models'].each_key do |model_name| + pull_single_model(registry, model_name) + break if @exit_code != 0 + end + end + + def command_push + @args.shift # remove 'push' + + @output << 'To contribute a model to the registry:' + @output << '' + @output << '1. Fork https://github.com/cardmagic/classifier-models' + @output << '2. Add your model to the models/ directory' + @output << '3. Update models.json with your model metadata' + @output << '4. Create a pull request' + @output << '' + @output << 'Or use the GitHub CLI:' + @output << '' + @output << ' gh repo fork cardmagic/classifier-models --clone' + @output << ' cp ./classifier.json classifier-models/models/my-model.json' + @output << ' # Edit classifier-models/models.json to add your model' + @output << ' cd classifier-models && gh pr create' + end + + def command_classify + text = @args.join(' ') + + if @options[:remote] + classify_with_remote(text) + return + end + + if text.empty? && ($stdin.tty? || @stdin.nil?) && !File.exist?(@options[:model]) + show_getting_started + return + end + + unless File.exist?(@options[:model]) + @error << "Error: model not found at #{@options[:model]}" + @exit_code = 1 + return + end + + classifier = load_classifier + + if text.empty? + lines = read_stdin_lines + return show_model_usage(classifier) if lines.empty? + + lines.each { |line| classify_and_output(classifier, line) } + else + classify_and_output(classifier, text) + end + end + + def classify_with_remote(text) + registry, model_name = parse_model_spec(@options[:remote]) + registry ||= DEFAULT_REGISTRY + + unless model_name + @error << 'Error: model name required for -r option' + @exit_code = 2 + return + end + + cached_path = cache_path_for(registry, model_name) + + unless File.exist?(cached_path) + pull_single_model(registry, model_name) + return if @exit_code != 0 + end + + original_model = @options[:model] + @options[:model] = cached_path + + begin + classifier = load_classifier + + if text.empty? + lines = read_stdin_lines + return show_model_usage(classifier) if lines.empty? + + lines.each { |line| classify_and_output(classifier, line) } + else + classify_and_output(classifier, text) + end + ensure + @options[:model] = original_model + end + end + + # @rbs (untyped) -> void + def show_model_usage(classifier) + type = classifier_type_name(classifier) + cats = classifier.categories.map(&:to_s).map(&:downcase) + first_cat = cats.first || 'category' + + @output << "Model: #{@options[:model]} (#{type})" + @output << "Categories: #{cats.join(', ')}" + @output << '' + @output << 'Classify text:' + @output << '' + @output << " classifier 'text to classify'" + @output << " echo 'text to classify' | classifier" + @output << '' + @output << 'Train more data:' + @output << '' + @output << " echo 'new example text' | classifier train #{first_cat}" + @output << " classifier train #{first_cat} file1.txt file2.txt" + @output << '' + @output << 'Other commands:' + @output << '' + @output << ' classifier info Show model details (JSON)' + end + + def classify_and_output(classifier, text) + return if text.strip.empty? + + if classifier.is_a?(LogisticRegression) && !classifier.fitted? + raise StandardError, "Model not fitted. Run 'classifier fit' after training." + end + + if @options[:probabilities] + probs = get_probabilities(classifier, text) + formatted = probs.map { |cat, prob| "#{cat.downcase}:#{format('%.2f', prob)}" }.join(' ') + @output << formatted + else + result = classifier.classify(text) + @output << result.downcase + end + end + + def get_probabilities(classifier, text) + if classifier.respond_to?(:probabilities) + classifier.probabilities(text) + elsif classifier.respond_to?(:classifications) + scores = classifier.classifications(text) + normalize_scores(scores) + else + { classifier.classify(text) => 1.0 } + end + end + + def normalize_scores(scores) + max_score = scores.values.max + exp_scores = scores.transform_values { |s| Math.exp(s - max_score) } + total = exp_scores.values.sum.to_f + exp_scores.transform_values { |s| (s / total).to_f } + end + + def load_or_create_classifier + if File.exist?(@options[:model]) + load_classifier + else + create_classifier + end + end + + def load_classifier + json = File.read(@options[:model]) + data = JSON.parse(json) + type = data['type'] + + case type + when 'bayes' + Bayes.from_json(data) + when 'lsi' + LSI.from_json(data) + when 'knn' + KNN.from_json(data) + when 'logistic_regression' + LogisticRegression.from_json(data) + else + raise "Unknown classifier type in model: #{type}" + end + end + + def create_classifier + type = CLASSIFIER_TYPES[@options[:type]] || :bayes + + case type + when :lsi + LSI.new(auto_rebuild: true) + when :knn + KNN.new(k: @options[:k], weighted: @options[:weighted]) + when :logistic_regression + lr_opts = {} #: Hash[Symbol, untyped] + lr_opts[:learning_rate] = @options[:learning_rate] if @options[:learning_rate] + lr_opts[:regularization] = @options[:regularization] if @options[:regularization] + lr_opts[:max_iterations] = @options[:max_iterations] if @options[:max_iterations] + LogisticRegression.new(**lr_opts) + else # :bayes or unknown defaults to Bayes + Bayes.new + end + end + + def train_classifier(classifier, category, text) + case classifier + when Bayes, LogisticRegression + classifier.add_category(category) unless classifier.categories.include?(category) + text.each_line { |line| classifier.train(category, line.strip) unless line.strip.empty? } + when LSI + text.each_line do |line| + next if line.strip.empty? + + classifier.add_item(line.strip, category.to_sym) + end + when KNN + text.each_line do |line| + next if line.strip.empty? + + classifier.add(category.to_sym => line.strip) + end + end + end + + def train_lsi_from_files(classifier, category, files) + files.each do |file| + content = File.read(file) + classifier.add_item(file, category.to_sym) { content } + end + end + + def save_classifier(classifier) + classifier.storage = Storage::File.new(path: @options[:model]) + classifier.save + end + + def classifier_type_name(classifier) + case classifier + when Bayes then 'bayes' + when LSI then 'lsi' + when KNN then 'knn' + when LogisticRegression then 'logistic_regression' + else 'unknown' + end + end + + def read_training_input + if @args.any? + @args.map { |file| File.read(file) }.join("\n") + else + read_stdin + end + end + + def read_stdin + @stdin || ($stdin.tty? ? '' : $stdin.read) + end + + def read_stdin_line + (@stdin || ($stdin.tty? ? '' : $stdin.read)).to_s.strip + end + + def read_stdin_lines + read_stdin.to_s.split("\n").map(&:strip).reject(&:empty?) + end + + # @rbs () -> void + def show_getting_started + @output << 'Classifier - Text classification from the command line' + @output << '' + @output << 'Get started by training some categories:' + @output << '' + @output << ' # Train from files' + @output << ' classifier train spam spam_emails/*.txt' + @output << ' classifier train ham good_emails/*.txt' + @output << '' + @output << ' # Train from stdin' + @output << " echo 'buy viagra now free pills cheap meds' | classifier train spam" + @output << " echo 'meeting scheduled for tomorrow to discuss project' | classifier train ham" + @output << '' + @output << 'Then classify text:' + @output << '' + @output << " classifier 'free money buy now'" + @output << " classifier 'meeting postponed to friday'" + @output << '' + @output << 'Use LSI for semantic search:' + @output << '' + @output << " echo 'ruby is a dynamic programming language' | classifier train docs -m lsi" + @output << " echo 'python is great for data science' | classifier train docs -m lsi" + @output << " classifier search 'programming'" + @output << '' + @output << 'Options:' + @output << ' -f FILE Model file (default: ./classifier.json)' + @output << ' -m TYPE Model type: bayes, lsi, knn, lr (default: bayes)' + @output << ' -r MODEL Use remote model from registry' + @output << ' -p Show probabilities' + @output << '' + @output << 'Use pre-trained models:' + @output << '' + @output << ' classifier models List available models' + @output << ' classifier pull sentiment Download a model' + @output << " classifier -r sentiment 'I love this!' Classify with remote model" + @output << '' + @output << 'Run "classifier --help" for full usage.' + end + + # Parse @user/repo format to extract registry + # @rbs (String?) -> String? + def parse_registry(arg) + return nil unless arg + return nil unless arg.start_with?('@') + + # @user/repo format + arg[1..] # Remove @ prefix + end + + # Parse model spec: name, @user/repo:name, or @user/repo (for all models) + # Returns [registry, model_name] where model_name is nil if pulling all + # @rbs (String) -> [String?, String?] + def parse_model_spec(spec) + if spec.start_with?('@') + # @user/repo:model or @user/repo + rest = spec[1..] || '' + if spec.include?(':') + parts = rest.split(':', 2) + [parts[0], parts[1]] + else + # @user/repo - pull all models from registry + [rest, nil] + end + else + # Just a model name from default registry + [nil, spec] + end + end + + # Get cache path for a model + # @rbs (String, String) -> String + def cache_path_for(registry, model_name) + if registry == DEFAULT_REGISTRY + File.join(CACHE_DIR, 'models', "#{model_name}.json") + else + File.join(CACHE_DIR, 'models', "@#{registry}", "#{model_name}.json") + end + end + + # Fetch models.json index from a registry + # @rbs (String) -> Hash[String, untyped] + def fetch_registry_index(registry) + content = fetch_github_file(registry, 'models.json') + return { 'models' => {} } if @exit_code != 0 + + JSON.parse(content) + rescue JSON::ParserError => e + @error << "Error: invalid models.json in registry: #{e.message}" + @exit_code = 1 + { 'models' => {} } + end + + # Fetch a file from GitHub raw content + # @rbs (String, String) -> String + def fetch_github_file(registry, file_path) + url = "https://raw.githubusercontent.com/#{registry}/main/#{file_path}" + uri = URI.parse(url) + + response = Net::HTTP.get_response(uri) + + unless response.is_a?(Net::HTTPSuccess) + # Try master branch if main fails + url = "https://raw.githubusercontent.com/#{registry}/master/#{file_path}" + uri = URI.parse(url) + response = Net::HTTP.get_response(uri) + end + + unless response.is_a?(Net::HTTPSuccess) + @error << "Error: failed to fetch #{file_path} from #{registry} (#{response.code})" + @exit_code = 1 + return '' + end + + response.body + end + end +end diff --git a/lib/classifier/logistic_regression.rb b/lib/classifier/logistic_regression.rb index 314904f..a169f85 100644 --- a/lib/classifier/logistic_regression.rb +++ b/lib/classifier/logistic_regression.rb @@ -62,8 +62,6 @@ def initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE, tolerance: DEFAULT_TOLERANCE) super() categories = categories.flatten - raise ArgumentError, 'At least two categories required' if categories.size < 2 - @categories = categories.map { |c| c.to_s.prepare_category_name } @weights = @categories.to_h { |c| [c, {}] } @bias = @categories.to_h { |c| [c, 0.0] } @@ -99,6 +97,7 @@ def train(category = nil, text = nil, **categories) def fit synchronize do return self if @training_data.empty? + raise ArgumentError, 'At least two categories required for fitting' if @categories.size < 2 optimize_weights @fitted = true @@ -122,13 +121,14 @@ def classify(text) # Returns probability distribution across all categories. # Probabilities are well-calibrated (unlike Naive Bayes). + # Raises NotFittedError if model has not been fitted. # # classifier.probabilities("Buy now!") # # => {"Spam" => 0.92, "Ham" => 0.08} # # @rbs (String) -> Hash[String, Float] def probabilities(text) - fit unless @fitted + raise NotFittedError, 'Model not fitted. Call fit() after training.' unless @fitted features = text.word_hash synchronize do @@ -137,10 +137,11 @@ def probabilities(text) end # Returns log-odds scores for each category (before softmax). + # Raises NotFittedError if model has not been fitted. # # @rbs (String) -> Hash[String, Float] def classifications(text) - fit unless @fitted + raise NotFittedError, 'Model not fitted. Call fit() after training.' unless @fitted features = text.word_hash synchronize do @@ -173,6 +174,23 @@ def categories synchronize { @categories.map(&:to_s) } end + # Adds a new category to the classifier. + # Allows dynamic category creation for CLI and incremental training. + # + # @rbs (String | Symbol) -> void + def add_category(category) + cat = category.to_s.prepare_category_name + synchronize do + return if @categories.include?(cat) + + @categories << cat + @weights[cat] = {} + @bias[cat] = 0.0 + @fitted = false + @dirty = true + end + end + # Returns true if the model has been fitted. # # @rbs () -> bool @@ -205,11 +223,10 @@ def respond_to_missing?(name, include_private = false) end # Returns a hash representation of the classifier state. + # Does NOT auto-fit; saves current state including unfitted models. # # @rbs (?untyped) -> Hash[Symbol, untyped] def as_json(_options = nil) - fit unless @fitted - { version: 1, type: 'logistic_regression', @@ -217,10 +234,12 @@ def as_json(_options = nil) weights: @weights.transform_keys(&:to_s).transform_values { |v| v.transform_keys(&:to_s) }, bias: @bias.transform_keys(&:to_s), vocabulary: @vocabulary.keys.map(&:to_s), + training_data: @training_data.map { |d| { category: d[:category].to_s, features: d[:features].transform_keys(&:to_s) } }, learning_rate: @learning_rate, regularization: @regularization, max_iterations: @max_iterations, - tolerance: @tolerance + tolerance: @tolerance, + fitted: @fitted } end @@ -546,26 +565,29 @@ def restore_from_json(json) def restore_state(data, categories) mu_initialize @categories = categories + restore_weights_and_bias(data) + restore_hyperparameters(data) + @fitted = data.fetch('fitted', true) + @dirty = false + @storage = nil + end + + def restore_weights_and_bias(data) @weights = {} @bias = {} - - data['weights'].each do |cat, words| - @weights[cat.to_sym] = words.transform_keys(&:to_sym).transform_values(&:to_f) - end - - data['bias'].each do |cat, value| - @bias[cat.to_sym] = value.to_f + data['weights'].each { |cat, words| @weights[cat.to_sym] = words.transform_keys(&:to_sym).transform_values(&:to_f) } + data['bias'].each { |cat, value| @bias[cat.to_sym] = value.to_f } + @vocabulary = data['vocabulary'].to_h { |v| [v.to_sym, true] } + @training_data = (data['training_data'] || []).map do |d| + { category: d['category'].to_sym, features: d['features'].transform_keys(&:to_sym).transform_values(&:to_i) } end + end - @vocabulary = data['vocabulary'].to_h { |v| [v.to_sym, true] } + def restore_hyperparameters(data) @learning_rate = data['learning_rate'] @regularization = data['regularization'] @max_iterations = data['max_iterations'] @tolerance = data['tolerance'] - @training_data = [] - @fitted = true - @dirty = false - @storage = nil end end end diff --git a/lib/classifier/version.rb b/lib/classifier/version.rb new file mode 100644 index 0000000..97c6235 --- /dev/null +++ b/lib/classifier/version.rb @@ -0,0 +1,3 @@ +module Classifier + VERSION = '2.3.0'.freeze +end diff --git a/sig/classifier.rbs b/sig/classifier.rbs new file mode 100644 index 0000000..a8946ce --- /dev/null +++ b/sig/classifier.rbs @@ -0,0 +1,3 @@ +module Classifier + VERSION: String +end diff --git a/sig/vendor/json.rbs b/sig/vendor/json.rbs index 271b9d3..44202a4 100644 --- a/sig/vendor/json.rbs +++ b/sig/vendor/json.rbs @@ -1,4 +1,5 @@ module JSON def self.parse: (String source, ?symbolize_names: bool) -> untyped def self.generate: (untyped obj) -> String + def self.pretty_generate: (untyped obj) -> String end diff --git a/sig/vendor/optparse.rbs b/sig/vendor/optparse.rbs new file mode 100644 index 0000000..b6a0dae --- /dev/null +++ b/sig/vendor/optparse.rbs @@ -0,0 +1,19 @@ +# Minimal type definitions for optparse stdlib + +class OptionParser + class InvalidOption < StandardError + end + + class MissingArgument < StandardError + end + + class InvalidArgument < StandardError + end + + def initialize: () { (OptionParser) -> void } -> void + def banner=: (String) -> String + def separator: (String) -> void + def on: (*untyped) ?{ (*untyped) -> untyped } -> void + def to_s: () -> String + def parse!: (Array[String]) -> Array[String] +end diff --git a/test/cli/cli_test.rb b/test/cli/cli_test.rb new file mode 100644 index 0000000..d42f95a --- /dev/null +++ b/test/cli/cli_test.rb @@ -0,0 +1,279 @@ +require_relative '../test_helper' +require 'classifier/cli' + +class CLITest < Minitest::Test + def setup + @tmpdir = Dir.mktmpdir + @model_path = File.join(@tmpdir, 'classifier.json') + end + + def teardown + FileUtils.remove_entry(@tmpdir) if @tmpdir && File.exist?(@tmpdir) + end + + # Helper to run CLI and capture output + def run_cli(*args, stdin: nil) + cli = Classifier::CLI.new(args, stdin: stdin) + cli.run + end + + # Helper to create a trained model for testing + def create_trained_bayes_model + run_cli('train', 'spam', '-f', @model_path, stdin: "buy now\nfree money\nlimited offer") + run_cli('train', 'ham', '-f', @model_path, stdin: "hello friend\nmeeting tomorrow\nproject update") + end + + # + # Version and Help + # + def test_version_flag + result = run_cli('-v') + + assert_match(/\d+\.\d+\.\d+/, result[:output]) + assert_equal 0, result[:exit_code] + end + + def test_help_flag + result = run_cli('-h') + + assert_match(/usage:/i, result[:output]) + assert_match(/train/, result[:output]) + assert_match(/classify/i, result[:output]) + assert_equal 0, result[:exit_code] + end + + def test_getting_started_when_no_model_and_no_args + result = run_cli('-f', @model_path) + + assert_match(/Get started by training/, result[:output]) + assert_match(/classifier train spam/, result[:output]) + assert_match(/classifier --help/, result[:output]) + assert_equal 0, result[:exit_code] + assert_empty result[:error] + end + + # + # Train Command + # + def test_train_from_stdin + result = run_cli('train', 'spam', '-f', @model_path, stdin: "buy now\nfree money") + + assert_equal 0, result[:exit_code] + assert_path_exists @model_path + end + + def test_train_from_file + corpus_file = File.join(@tmpdir, 'spam.txt') + File.write(corpus_file, "buy now\nfree money\nlimited offer") + + result = run_cli('train', 'spam', '-f', @model_path, corpus_file) + + assert_equal 0, result[:exit_code] + assert_path_exists @model_path + end + + def test_train_multiple_files + file1 = File.join(@tmpdir, 'spam1.txt') + file2 = File.join(@tmpdir, 'spam2.txt') + File.write(file1, 'buy now') + File.write(file2, 'free money') + + result = run_cli('train', 'spam', '-f', @model_path, file1, file2) + + assert_equal 0, result[:exit_code] + end + + def test_train_requires_category + result = run_cli('train', '-f', @model_path) + + assert_equal 2, result[:exit_code] + assert_match(/category/i, result[:error]) + end + + def test_train_multiple_categories + run_cli('train', 'spam', '-f', @model_path, stdin: 'buy now') + result = run_cli('train', 'ham', '-f', @model_path, stdin: 'hello friend') + + assert_equal 0, result[:exit_code] + + # Verify both categories exist + result = run_cli('info', '-f', @model_path) + info = JSON.parse(result[:output]) + + assert_includes info['categories'], 'Spam' + assert_includes info['categories'], 'Ham' + end + + # + # Classify Command (Default Action) + # + def test_classify_text_argument + create_trained_bayes_model + + result = run_cli('buy now free money', '-f', @model_path) + + assert_equal 0, result[:exit_code] + assert_equal 'spam', result[:output].strip.downcase + end + + def test_classify_from_stdin + create_trained_bayes_model + + result = run_cli('-f', @model_path, stdin: 'buy now free money') + + assert_equal 0, result[:exit_code] + assert_equal 'spam', result[:output].strip.downcase + end + + def test_classify_multiple_lines_from_stdin + create_trained_bayes_model + + result = run_cli('-f', @model_path, stdin: "buy now\nmeeting tomorrow") + lines = result[:output].strip.split("\n").map(&:downcase) + + assert_equal 2, lines.size + assert_equal 'spam', lines[0] + assert_equal 'ham', lines[1] + end + + def test_classify_with_probabilities + create_trained_bayes_model + + result = run_cli('-p', 'buy now free money', '-f', @model_path) + + assert_equal 0, result[:exit_code] + assert_match(/spam:\d+\.\d+/, result[:output].downcase) + assert_match(/ham:\d+\.\d+/, result[:output].downcase) + end + + def test_classify_without_model_fails + result = run_cli('some text', '-f', '/nonexistent/model.json') + + assert_equal 1, result[:exit_code] + assert_match(/model|not found|exist/i, result[:error]) + end + + # + # Info Command + # + def test_info_shows_model_details + create_trained_bayes_model + + result = run_cli('info', '-f', @model_path) + + assert_equal 0, result[:exit_code] + info = JSON.parse(result[:output]) + + assert_equal 'bayes', info['type'] + assert_includes info['categories'], 'Spam' + assert_includes info['categories'], 'Ham' + assert_operator info['category_stats']['Spam']['unique_words'], :>, 0 + assert_operator info['category_stats']['Ham']['unique_words'], :>, 0 + end + + def test_info_without_model_fails + result = run_cli('info', '-f', '/nonexistent/model.json') + + assert_equal 1, result[:exit_code] + end + + # + # Classifier Types + # + def test_train_with_lsi_type + result = run_cli('-m', 'lsi', 'train', 'tech', '-f', @model_path, stdin: 'ruby programming language') + + assert_equal 0, result[:exit_code] + + result = run_cli('info', '-f', @model_path) + info = JSON.parse(result[:output]) + + assert_equal 'lsi', info['type'] + end + + def test_train_with_knn_type + result = run_cli('-m', 'knn', 'train', 'tech', '-f', @model_path, stdin: 'ruby programming language') + + assert_equal 0, result[:exit_code] + + result = run_cli('info', '-f', @model_path) + info = JSON.parse(result[:output]) + + assert_equal 'knn', info['type'] + end + + def test_train_with_lr_type + result = run_cli('-m', 'lr', 'train', 'tech', '-f', @model_path, stdin: 'ruby programming language') + + assert_equal 0, result[:exit_code] + + result = run_cli('info', '-f', @model_path) + info = JSON.parse(result[:output]) + + assert_equal 'logistic_regression', info['type'] + end + + def test_invalid_classifier_type + result = run_cli('-m', 'invalid', 'train', 'spam', '-f', @model_path, stdin: 'test') + + assert_equal 2, result[:exit_code] + assert_match(/invalid|unknown|type/i, result[:error]) + end + + # + # KNN Options + # + def test_knn_with_k_option + run_cli('-m', 'knn', '-k', '3', 'train', 'tech', '-f', @model_path, stdin: 'ruby programming') + run_cli('-m', 'knn', 'train', 'sports', '-f', @model_path, stdin: 'football soccer') + + result = run_cli('-m', 'knn', '-k', '3', 'ruby code', '-f', @model_path) + + assert_equal 0, result[:exit_code] + end + + # + # Environment Variables + # + def test_model_from_environment_variable + create_trained_bayes_model + + ENV['CLASSIFIER_MODEL'] = @model_path + result = run_cli('buy now free money') + + assert_equal 0, result[:exit_code] + assert_equal 'spam', result[:output].strip.downcase + ensure + ENV.delete('CLASSIFIER_MODEL') + end + + def test_type_from_environment_variable + ENV['CLASSIFIER_TYPE'] = 'lsi' + ENV['CLASSIFIER_MODEL'] = @model_path + + result = run_cli('train', 'tech', stdin: 'programming code') + + assert_equal 0, result[:exit_code] + + result = run_cli('info') + info = JSON.parse(result[:output]) + + assert_equal 'lsi', info['type'] + ensure + ENV.delete('CLASSIFIER_TYPE') + ENV.delete('CLASSIFIER_MODEL') + end + + # + # Quiet Mode + # + def test_quiet_mode_minimal_output + create_trained_bayes_model + + result = run_cli('-q', 'buy now', '-f', @model_path) + + assert_equal 0, result[:exit_code] + # Quiet mode should just output the category, nothing else + assert_equal 'spam', result[:output].strip.downcase + end +end diff --git a/test/cli/lr_commands_test.rb b/test/cli/lr_commands_test.rb new file mode 100644 index 0000000..00a3aee --- /dev/null +++ b/test/cli/lr_commands_test.rb @@ -0,0 +1,142 @@ +require_relative '../test_helper' +require 'classifier/cli' + +class LRCommandsTest < Minitest::Test + def setup + @tmpdir = Dir.mktmpdir + @model_path = File.join(@tmpdir, 'classifier.json') + end + + def teardown + FileUtils.remove_entry(@tmpdir) if @tmpdir && File.exist?(@tmpdir) + end + + def run_cli(*args, stdin: nil) + cli = Classifier::CLI.new(args, stdin: stdin) + cli.run + end + + def create_trained_lr_model + run_cli('-m', 'lr', 'train', 'spam', '-f', @model_path, stdin: "buy now\nfree money\nlimited offer\nclick here") + run_cli('-m', 'lr', 'train', 'ham', '-f', @model_path, stdin: "hello friend\nmeeting tomorrow\nproject update\nweekly report") + end + + def create_fitted_lr_model + create_trained_lr_model + run_cli('fit', '-f', @model_path) + end + + # + # Explicit Fit Required + # + def test_classify_without_fit_fails + create_trained_lr_model + + # Should fail - model not fitted + result = run_cli('-m', 'lr', 'buy now free money', '-f', @model_path) + + assert_equal 1, result[:exit_code] + assert_match(/not fitted|run.*fit/i, result[:error]) + end + + def test_classify_after_fit_succeeds + create_fitted_lr_model + + result = run_cli('-m', 'lr', 'buy now free money', '-f', @model_path) + + assert_equal 0, result[:exit_code] + assert_equal 'spam', result[:output].strip.downcase + end + + def test_lr_info_shows_fit_status_before_fit + create_trained_lr_model + + result = run_cli('info', '-f', @model_path) + info = JSON.parse(result[:output]) + + refute info['fitted'] + end + + def test_lr_info_shows_fit_status_after_fit + create_fitted_lr_model + + result = run_cli('info', '-f', @model_path) + info = JSON.parse(result[:output]) + + assert info['fitted'] + end + + # + # Fit Command + # + def test_fit_command + create_trained_lr_model + + result = run_cli('fit', '-f', @model_path) + + assert_equal 0, result[:exit_code] + + result = run_cli('info', '-f', @model_path) + info = JSON.parse(result[:output]) + + assert info['fitted'] + end + + def test_fit_on_bayes_is_noop + # Create a bayes model + run_cli('train', 'spam', '-f', @model_path, stdin: 'buy now') + + result = run_cli('fit', '-f', @model_path) + + # Should succeed (no-op for Bayes) + assert_equal 0, result[:exit_code] + end + + def test_fit_after_additional_training_invalidates + create_fitted_lr_model + + # Add more training data + run_cli('-m', 'lr', 'train', 'spam', '-f', @model_path, stdin: 'win big prizes') + + # Info should show needs re-fitting + result = run_cli('info', '-f', @model_path) + info = JSON.parse(result[:output]) + + refute info['fitted'] + end + + # + # LR Hyperparameters + # + def test_lr_with_learning_rate + result = run_cli('-m', 'lr', '--learning-rate', '0.001', 'train', 'spam', '-f', @model_path, stdin: 'buy now') + + assert_equal 0, result[:exit_code] + end + + def test_lr_with_regularization + result = run_cli('-m', 'lr', '--regularization', '0.1', 'train', 'spam', '-f', @model_path, stdin: 'buy now') + + assert_equal 0, result[:exit_code] + end + + def test_lr_with_max_iterations + result = run_cli('-m', 'lr', '--max-iterations', '500', 'train', 'spam', '-f', @model_path, stdin: 'buy now') + + assert_equal 0, result[:exit_code] + end + + # + # LR Classification with Probabilities + # + def test_lr_classify_with_probabilities + create_fitted_lr_model + + result = run_cli('-m', 'lr', '-p', 'buy now free money', '-f', @model_path) + + assert_equal 0, result[:exit_code] + # LR should give good probability estimates + assert_match(/spam:\d+\.\d+/, result[:output].downcase) + assert_match(/ham:\d+\.\d+/, result[:output].downcase) + end +end diff --git a/test/cli/lsi_commands_test.rb b/test/cli/lsi_commands_test.rb new file mode 100644 index 0000000..34c9bb1 --- /dev/null +++ b/test/cli/lsi_commands_test.rb @@ -0,0 +1,131 @@ +require_relative '../test_helper' +require 'classifier/cli' + +class LSICommandsTest < Minitest::Test + def setup + @tmpdir = Dir.mktmpdir + @model_path = File.join(@tmpdir, 'classifier.json') + create_trained_lsi_model + end + + def teardown + FileUtils.remove_entry(@tmpdir) if @tmpdir && File.exist?(@tmpdir) + end + + def run_cli(*args, stdin: nil) + cli = Classifier::CLI.new(args, stdin: stdin) + cli.run + end + + def create_trained_lsi_model + # Create article files for training + @articles = {} + + @articles['ruby.txt'] = File.join(@tmpdir, 'ruby.txt') + File.write(@articles['ruby.txt'], 'Ruby is an elegant programming language for web development') + + @articles['python.txt'] = File.join(@tmpdir, 'python.txt') + File.write(@articles['python.txt'], 'Python is a programming language for data science') + + @articles['rails.txt'] = File.join(@tmpdir, 'rails.txt') + File.write(@articles['rails.txt'], 'Rails is a web framework built with Ruby programming') + + @articles['football.txt'] = File.join(@tmpdir, 'football.txt') + File.write(@articles['football.txt'], 'Football is a popular sport with teams and goals') + + # Train LSI model + run_cli('-m', 'lsi', 'train', 'tech', '-f', @model_path, @articles['ruby.txt'], @articles['python.txt'], @articles['rails.txt']) + run_cli('-m', 'lsi', 'train', 'sports', '-f', @model_path, @articles['football.txt']) + end + + # + # Search Command + # + def test_search_returns_ranked_documents + result = run_cli('search', 'programming language', '-f', @model_path) + + assert_equal 0, result[:exit_code] + # Should return documents with scores + assert_match(/\.txt:\d+\.\d+/, result[:output]) + end + + def test_search_from_stdin + result = run_cli('search', '-f', @model_path, stdin: 'web development') + + assert_equal 0, result[:exit_code] + assert_match(/\.txt:\d+\.\d+/, result[:output]) + end + + def test_search_with_count_limit + result = run_cli('search', '-n', '2', 'programming', '-f', @model_path) + + assert_equal 0, result[:exit_code] + lines = result[:output].strip.split("\n") + + assert_operator lines.size, :<=, 2 + end + + def test_search_fails_on_bayes_model + bayes_model = File.join(@tmpdir, 'bayes.json') + run_cli('train', 'spam', '-f', bayes_model, stdin: 'buy now') + + result = run_cli('search', 'query', '-f', bayes_model) + + assert_equal 1, result[:exit_code] + assert_match(/lsi|search.*requires/i, result[:error]) + end + + # + # Related Command + # + def test_related_finds_similar_documents + result = run_cli('related', @articles['ruby.txt'], '-f', @model_path) + + assert_equal 0, result[:exit_code] + # Should find rails.txt as related (both about Ruby) + assert_match(/\.txt:\d+\.\d+/, result[:output]) + end + + def test_related_with_count_limit + result = run_cli('related', '-n', '1', @articles['ruby.txt'], '-f', @model_path) + + assert_equal 0, result[:exit_code] + lines = result[:output].strip.split("\n") + + assert_equal 1, lines.size + end + + def test_related_fails_on_bayes_model + bayes_model = File.join(@tmpdir, 'bayes.json') + run_cli('train', 'spam', '-f', bayes_model, stdin: 'buy now') + + result = run_cli('related', 'some_file.txt', '-f', bayes_model) + + assert_equal 1, result[:exit_code] + assert_match(/lsi|related.*requires/i, result[:error]) + end + + def test_related_with_nonexistent_item + result = run_cli('related', '/nonexistent/file.txt', '-f', @model_path) + + assert_equal 1, result[:exit_code] + assert_match(/not found|unknown|item/i, result[:error]) + end + + # + # LSI Classification + # + def test_lsi_classify_text + result = run_cli('-m', 'lsi', 'Ruby web framework', '-f', @model_path) + + assert_equal 0, result[:exit_code] + assert_equal 'tech', result[:output].strip.downcase + end + + def test_lsi_classify_with_probabilities + result = run_cli('-m', 'lsi', '-p', 'Ruby web framework', '-f', @model_path) + + assert_equal 0, result[:exit_code] + assert_match(/tech:\d+\.\d+/, result[:output].downcase) + end +end diff --git a/test/cli/registry_commands_test.rb b/test/cli/registry_commands_test.rb new file mode 100644 index 0000000..613d8c3 --- /dev/null +++ b/test/cli/registry_commands_test.rb @@ -0,0 +1,340 @@ +require_relative '../test_helper' +require 'classifier/cli' +require 'webmock/minitest' + +class RegistryCommandsTest < Minitest::Test + def setup + @tmpdir = Dir.mktmpdir + @model_path = File.join(@tmpdir, 'classifier.json') + @cache_dir = File.join(@tmpdir, 'cache') + + # Override cache directory for tests + @original_cache = Classifier::CLI::CACHE_DIR + Classifier::CLI.send(:remove_const, :CACHE_DIR) + Classifier::CLI.const_set(:CACHE_DIR, @cache_dir) + + # Mock models.json response + @models_json = { + 'version' => '1.0.0', + 'models' => { + 'spam-filter' => { + 'description' => 'Email spam detection', + 'type' => 'bayes', + 'categories' => %w[spam ham], + 'file' => 'models/spam-filter.json', + 'size' => '245KB' + }, + 'sentiment' => { + 'description' => 'Sentiment analysis', + 'type' => 'bayes', + 'categories' => %w[positive negative neutral], + 'file' => 'models/sentiment.json', + 'size' => '1.2MB' + } + } + }.to_json + + # Mock classifier model response + @model_json = Classifier::Bayes.new('Spam', 'Ham').tap do |b| + b.train('Spam', 'buy now free money cheap') + b.train('Ham', 'hello friend meeting project') + end.to_json + end + + def teardown + FileUtils.remove_entry(@tmpdir) if @tmpdir && File.exist?(@tmpdir) + + # Restore original cache directory + Classifier::CLI.send(:remove_const, :CACHE_DIR) + Classifier::CLI.const_set(:CACHE_DIR, @original_cache) + + WebMock.reset! + end + + def run_cli(*args, stdin: nil) + cli = Classifier::CLI.new(args, stdin: stdin) + cli.run + end + + # + # Models Command + # + def test_models_lists_available_models + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models.json') + .to_return(status: 200, body: @models_json) + + result = run_cli('models') + + assert_equal 0, result[:exit_code] + assert_match(/spam-filter/, result[:output]) + assert_match(/sentiment/, result[:output]) + assert_match(/bayes/, result[:output]) + assert_empty result[:error] + end + + def test_models_from_custom_registry + stub_request(:get, 'https://raw.githubusercontent.com/someone/models/main/models.json') + .to_return(status: 200, body: @models_json) + + result = run_cli('models', '@someone/models') + + assert_equal 0, result[:exit_code] + assert_match(/spam-filter/, result[:output]) + assert_empty result[:error] + end + + def test_models_handles_empty_registry + empty_json = { 'version' => '1.0.0', 'models' => {} }.to_json + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models.json') + .to_return(status: 200, body: empty_json) + + result = run_cli('models') + + assert_equal 0, result[:exit_code] + assert_match(/no models found/i, result[:output]) + end + + def test_models_handles_network_error + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models.json') + .to_return(status: 404) + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/master/models.json') + .to_return(status: 404) + + result = run_cli('models') + + assert_equal 1, result[:exit_code] + assert_match(/failed to fetch/i, result[:error]) + end + + def test_models_local_lists_cached_models + # Create some cached models + models_dir = File.join(@cache_dir, 'models') + FileUtils.mkdir_p(models_dir) + File.write(File.join(models_dir, 'spam-filter.json'), @model_json) + File.write(File.join(models_dir, 'sentiment.json'), @model_json) + + result = run_cli('models', '--local') + + assert_equal 0, result[:exit_code] + assert_match(/spam-filter/, result[:output]) + assert_match(/sentiment/, result[:output]) + assert_match(/bayes/, result[:output]) + assert_empty result[:error] + end + + def test_models_local_lists_models_from_custom_registries + # Create cached model from custom registry + custom_dir = File.join(@cache_dir, 'models', '@someone/models') + FileUtils.mkdir_p(custom_dir) + File.write(File.join(custom_dir, 'custom-model.json'), @model_json) + + result = run_cli('models', '--local') + + assert_equal 0, result[:exit_code] + assert_match(%r{@someone/models:custom-model}, result[:output]) + end + + def test_models_local_shows_no_models_when_cache_empty + result = run_cli('models', '--local') + + assert_equal 0, result[:exit_code] + assert_match(/no local models found/i, result[:output]) + end + + def test_models_local_shows_no_models_when_cache_dir_missing + # Cache dir doesn't exist by default in test setup + FileUtils.rm_rf(@cache_dir) + + result = run_cli('models', '--local') + + assert_equal 0, result[:exit_code] + assert_match(/no local models found/i, result[:output]) + end + + # + # Pull Command + # + def test_pull_downloads_model + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models.json') + .to_return(status: 200, body: @models_json) + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models/spam-filter.json') + .to_return(status: 200, body: @model_json) + + result = run_cli('pull', 'spam-filter') + + assert_equal 0, result[:exit_code] + assert_match(/downloading/i, result[:output]) + assert_match(/saved/i, result[:output]) + + cached_path = File.join(@cache_dir, 'models', 'spam-filter.json') + + assert_path_exists cached_path + end + + def test_pull_with_custom_output_path + output_path = File.join(@tmpdir, 'my-model.json') + + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models.json') + .to_return(status: 200, body: @models_json) + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models/spam-filter.json') + .to_return(status: 200, body: @model_json) + + result = run_cli('pull', 'spam-filter', '-o', output_path) + + assert_equal 0, result[:exit_code] + assert_path_exists output_path + end + + def test_pull_from_custom_registry + stub_request(:get, 'https://raw.githubusercontent.com/someone/models/main/models.json') + .to_return(status: 200, body: @models_json) + stub_request(:get, 'https://raw.githubusercontent.com/someone/models/main/models/spam-filter.json') + .to_return(status: 200, body: @model_json) + + result = run_cli('pull', '@someone/models:spam-filter') + + assert_equal 0, result[:exit_code] + + cached_path = File.join(@cache_dir, 'models', '@someone/models', 'spam-filter.json') + + assert_path_exists cached_path + end + + def test_pull_model_not_found + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models.json') + .to_return(status: 200, body: @models_json) + + result = run_cli('pull', 'nonexistent') + + assert_equal 1, result[:exit_code] + assert_match(/not found/i, result[:error]) + end + + def test_pull_requires_model_name + result = run_cli('pull') + + assert_equal 2, result[:exit_code] + assert_match(/model name required/i, result[:error]) + end + + def test_pull_quiet_mode + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models.json') + .to_return(status: 200, body: @models_json) + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models/spam-filter.json') + .to_return(status: 200, body: @model_json) + + result = run_cli('pull', 'spam-filter', '-q') + + assert_equal 0, result[:exit_code] + assert_empty result[:output] + end + + # + # Push Command + # + def test_push_shows_instructions + result = run_cli('push', 'my-model.json') + + assert_equal 0, result[:exit_code] + assert_match(/fork/i, result[:output]) + assert_match(/classifier-models/i, result[:output]) + assert_match(/pull request/i, result[:output]) + end + + # + # Remote Classification (-r) + # + def test_classify_with_remote_model + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models.json') + .to_return(status: 200, body: @models_json) + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models/spam-filter.json') + .to_return(status: 200, body: @model_json) + + result = run_cli('-r', 'spam-filter', 'buy now free money') + + assert_equal 0, result[:exit_code] + # Last line should be the classification result + assert_equal 'spam', result[:output].strip.split("\n").last.downcase + end + + def test_classify_with_cached_remote_model + # Pre-cache the model + cached_path = File.join(@cache_dir, 'models', 'spam-filter.json') + FileUtils.mkdir_p(File.dirname(cached_path)) + File.write(cached_path, @model_json) + + # Should not make any network requests since model is cached + result = run_cli('-r', 'spam-filter', 'buy now free money') + + assert_equal 0, result[:exit_code] + assert_equal 'spam', result[:output].strip.downcase + end + + def test_classify_with_remote_from_custom_registry + stub_request(:get, 'https://raw.githubusercontent.com/someone/models/main/models.json') + .to_return(status: 200, body: @models_json) + stub_request(:get, 'https://raw.githubusercontent.com/someone/models/main/models/spam-filter.json') + .to_return(status: 200, body: @model_json) + + result = run_cli('-r', '@someone/models:spam-filter', 'buy now free money') + + assert_equal 0, result[:exit_code] + # Last line should be the classification result + assert_equal 'spam', result[:output].strip.split("\n").last.downcase + end + + def test_classify_with_probabilities_and_remote + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models.json') + .to_return(status: 200, body: @models_json) + stub_request(:get, 'https://raw.githubusercontent.com/cardmagic/classifier-models/main/models/spam-filter.json') + .to_return(status: 200, body: @model_json) + + result = run_cli('-r', 'spam-filter', '-p', 'buy now free money') + + assert_equal 0, result[:exit_code] + assert_match(/spam:\d+\.\d+/, result[:output].downcase) + assert_match(/ham:\d+\.\d+/, result[:output].downcase) + end + + # + # Helper Methods + # + def test_parse_model_spec_simple_name + cli = Classifier::CLI.new([]) + registry, model = cli.send(:parse_model_spec, 'sentiment') + + assert_nil registry + assert_equal 'sentiment', model + end + + def test_parse_model_spec_custom_registry + cli = Classifier::CLI.new([]) + registry, model = cli.send(:parse_model_spec, '@user/repo:sentiment') + + assert_equal 'user/repo', registry + assert_equal 'sentiment', model + end + + def test_parse_model_spec_registry_only + cli = Classifier::CLI.new([]) + registry, model = cli.send(:parse_model_spec, '@user/repo') + + assert_equal 'user/repo', registry + assert_nil model + end + + def test_cache_path_for_default_registry + cli = Classifier::CLI.new([]) + path = cli.send(:cache_path_for, 'cardmagic/classifier-models', 'sentiment') + + assert_match %r{models/sentiment\.json$}, path + refute_match(/@/, path) + end + + def test_cache_path_for_custom_registry + cli = Classifier::CLI.new([]) + path = cli.send(:cache_path_for, 'user/repo', 'sentiment') + + assert_match %r{@user/repo/sentiment\.json$}, path + end +end diff --git a/test/logistic_regression/logistic_regression_test.rb b/test/logistic_regression/logistic_regression_test.rb index 419121f..5b3c089 100644 --- a/test/logistic_regression/logistic_regression_test.rb +++ b/test/logistic_regression/logistic_regression_test.rb @@ -5,10 +5,21 @@ def setup @classifier = Classifier::LogisticRegression.new 'Spam', 'Ham' end + # Helper: train and fit for tests that need classification + def train_and_fit + @classifier.train_spam 'buy now free money offer limited' + @classifier.train_ham 'hello friend meeting project update' + @classifier.fit + end + # Initialization tests - def test_requires_at_least_two_categories - assert_raises(ArgumentError) { Classifier::LogisticRegression.new 'Only' } + def test_requires_at_least_two_categories_at_fit + classifier = Classifier::LogisticRegression.new 'Only' + classifier.train(:only, 'test text') + + # Error raised at fit time, not initialization + assert_raises(ArgumentError) { classifier.fit } end def test_accepts_symbols_and_strings @@ -25,8 +36,12 @@ def test_accepts_array_of_categories assert_equal %w[Ham Spam], classifier.categories.sort end - def test_array_initialization_requires_at_least_two - assert_raises(ArgumentError) { Classifier::LogisticRegression.new(['Only']) } + def test_array_initialization_requires_at_least_two_at_fit + classifier = Classifier::LogisticRegression.new(['Only']) + classifier.train(:only, 'test text') + + # Error raised at fit time, not initialization + assert_raises(ArgumentError) { classifier.fit } end def test_custom_hyperparameters @@ -50,6 +65,7 @@ def test_categories def test_train_with_positional_arguments @classifier.train :spam, 'Buy now! Free money!' @classifier.train :ham, 'Hello friend, meeting tomorrow' + @classifier.fit assert_equal 'Spam', @classifier.classify('Buy free money') assert_equal 'Ham', @classifier.classify('Hello meeting friend') @@ -58,6 +74,7 @@ def test_train_with_positional_arguments def test_train_with_keyword_arguments @classifier.train(spam: 'Buy now! Free money!') @classifier.train(ham: 'Hello friend, meeting tomorrow') + @classifier.fit assert_equal 'Spam', @classifier.classify('Buy free money') assert_equal 'Ham', @classifier.classify('Hello meeting friend') @@ -66,6 +83,7 @@ def test_train_with_keyword_arguments def test_train_with_array_value @classifier.train(spam: ['Buy now!', 'Free money!', 'Click here!']) @classifier.train(ham: 'Normal email content') + @classifier.fit assert_equal 'Spam', @classifier.classify('Buy click free') end @@ -75,6 +93,7 @@ def test_train_with_multiple_categories spam: ['Buy now!', 'Free money!'], ham: ['Hello friend', 'Meeting tomorrow'] ) + @classifier.fit assert_equal 'Spam', @classifier.classify('Buy free') assert_equal 'Ham', @classifier.classify('Hello meeting') @@ -83,6 +102,7 @@ def test_train_with_multiple_categories def test_train_dynamic_method @classifier.train_spam 'Buy now! Free money!' @classifier.train_ham 'Hello friend' + @classifier.fit assert_equal 'Spam', @classifier.classify('Buy free money') assert_equal 'Ham', @classifier.classify('Hello friend') @@ -98,6 +118,7 @@ def test_train_invalid_category def test_classify_basic @classifier.train_spam 'Buy now! Free money! Limited offer!' @classifier.train_ham 'Hello, how are you? Meeting tomorrow.' + @classifier.fit assert_equal 'Spam', @classifier.classify('Free money offer') assert_equal 'Ham', @classifier.classify('Hello, how are you?') @@ -121,6 +142,7 @@ def test_classify_with_more_training_data 'Thanks for your help yesterday', 'Looking forward to seeing you' ]) + @classifier.fit assert_equal 'Spam', @classifier.classify('Free prize money') assert_equal 'Ham', @classifier.classify('Project meeting tomorrow') @@ -129,6 +151,7 @@ def test_classify_with_more_training_data def test_classifications_returns_scores @classifier.train_spam 'spam words' @classifier.train_ham 'ham words' + @classifier.fit scores = @classifier.classifications('spam words') @@ -143,6 +166,7 @@ def test_classifications_returns_scores def test_probabilities_sum_to_one @classifier.train_spam 'spam words here' @classifier.train_ham 'ham words here' + @classifier.fit probs = @classifier.probabilities('test words') @@ -152,6 +176,7 @@ def test_probabilities_sum_to_one def test_probabilities_are_between_zero_and_one @classifier.train_spam 'spam words here' @classifier.train_ham 'ham words here' + @classifier.fit probs = @classifier.probabilities('test words') @@ -164,6 +189,7 @@ def test_probabilities_are_between_zero_and_one def test_probabilities_reflect_confidence @classifier.train(spam: ['spam spam spam'] * 10) @classifier.train(ham: ['ham ham ham'] * 10) + @classifier.fit spam_probs = @classifier.probabilities('spam spam spam') ham_probs = @classifier.probabilities('ham ham ham') @@ -177,6 +203,7 @@ def test_probabilities_reflect_confidence def test_weights_returns_hash @classifier.train_spam 'buy free money' @classifier.train_ham 'hello friend meeting' + @classifier.fit weights = @classifier.weights(:spam) @@ -187,6 +214,7 @@ def test_weights_returns_hash def test_weights_sorted_by_importance @classifier.train_spam 'spam spam spam important' @classifier.train_ham 'ham ham ham' + @classifier.fit weights = @classifier.weights(:spam) values = weights.values @@ -200,6 +228,7 @@ def test_weights_sorted_by_importance def test_weights_with_limit @classifier.train_spam 'one two three four five' @classifier.train_ham 'six seven eight nine ten' + @classifier.fit weights = @classifier.weights(:spam, limit: 3) @@ -209,6 +238,7 @@ def test_weights_with_limit def test_weights_invalid_category @classifier.train_spam 'spam' @classifier.train_ham 'ham' + @classifier.fit assert_raises(StandardError) { @classifier.weights(:invalid) } end @@ -223,6 +253,7 @@ def test_fitted_state refute_predicate @classifier, :fitted? + @classifier.fit @classifier.classify('test') assert_predicate @classifier, :fitted? @@ -238,24 +269,31 @@ def test_fit_explicitly assert_predicate @classifier, :fitted? end - def test_auto_fit_on_classify + def test_classify_without_fit_raises_error @classifier.train_spam 'spam' @classifier.train_ham 'ham' refute_predicate @classifier, :fitted? - @classifier.classify('test') + assert_raises(Classifier::NotFittedError) { @classifier.classify('test') } + end - assert_predicate @classifier, :fitted? + def test_probabilities_without_fit_raises_error + @classifier.train_spam 'spam' + @classifier.train_ham 'ham' + + refute_predicate @classifier, :fitted? + assert_raises(Classifier::NotFittedError) { @classifier.probabilities('test') } end - def test_auto_fit_on_probabilities + def test_explicit_fit_enables_classify @classifier.train_spam 'spam' @classifier.train_ham 'ham' refute_predicate @classifier, :fitted? - @classifier.probabilities('test') + @classifier.fit assert_predicate @classifier, :fitted? + assert_equal 'Spam', @classifier.classify('spam words') end # Multi-class tests @@ -266,6 +304,7 @@ def test_three_class_classification classifier.train(positive: ['great amazing wonderful love happy']) classifier.train(negative: ['terrible awful hate bad angry']) classifier.train(neutral: ['okay average normal regular']) + classifier.fit assert_equal 'Positive', classifier.classify('great love happy') assert_equal 'Negative', classifier.classify('terrible hate angry') @@ -276,6 +315,7 @@ def test_multi_class_probabilities_sum_to_one classifier = Classifier::LogisticRegression.new :a, :b, :c, :d classifier.train(a: 'alpha', b: 'beta', c: 'gamma', d: 'delta') + classifier.fit probs = classifier.probabilities('test') @@ -313,6 +353,7 @@ def test_to_json def test_from_json_with_string @classifier.train_spam 'spam words' @classifier.train_ham 'ham words' + @classifier.fit json = @classifier.to_json loaded = Classifier::LogisticRegression.from_json(json) @@ -325,6 +366,7 @@ def test_from_json_with_string def test_from_json_with_hash @classifier.train_spam 'spam words' @classifier.train_ham 'ham words' + @classifier.fit hash = JSON.parse(@classifier.to_json) loaded = Classifier::LogisticRegression.from_json(hash) @@ -341,6 +383,7 @@ def test_from_json_invalid_type def test_save_and_load_file @classifier.train_spam 'spam words' @classifier.train_ham 'ham words' + @classifier.fit Dir.mktmpdir do |dir| path = File.join(dir, 'classifier.json') @@ -358,6 +401,7 @@ def test_save_and_load_file def test_loaded_classifier_preserves_predictions @classifier.train(spam: ['buy free money offer'] * 5) @classifier.train(ham: ['hello meeting project friend'] * 5) + @classifier.fit Dir.mktmpdir do |dir| path = File.join(dir, 'classifier.json') @@ -428,6 +472,7 @@ def test_empty_string_training def test_empty_string_classification @classifier.train_spam 'spam words' @classifier.train_ham 'ham words' + @classifier.fit result = @classifier.classify('') @@ -437,6 +482,7 @@ def test_empty_string_classification def test_unicode_text @classifier.train_spam 'spam japonais 日本語' @classifier.train_ham 'ham chinese 中文' + @classifier.fit # Should handle unicode without crashing result = @classifier.classify('日本語 test') @@ -447,6 +493,7 @@ def test_unicode_text def test_single_word_documents @classifier.train_spam 'spam' @classifier.train_ham 'ham' + @classifier.fit assert_equal 'Spam', @classifier.classify('spam') assert_equal 'Ham', @classifier.classify('ham') @@ -458,6 +505,7 @@ def test_very_long_text @classifier.train_spam long_spam @classifier.train_ham long_ham + @classifier.fit assert_equal 'Spam', @classifier.classify('buy free money') end @@ -465,6 +513,7 @@ def test_very_long_text def test_special_characters @classifier.train_spam 'Buy! @#$% now!!!' @classifier.train_ham 'Hello... how are you???' + @classifier.fit # Should not crash on special characters @classifier.classify('!@#$%^&*()') @@ -478,6 +527,7 @@ def test_softmax_numerical_stability @classifier.train_spam 'spam spam spam spam spam' @classifier.train_ham 'ham ham ham ham ham' end + @classifier.fit probs = @classifier.probabilities('spam spam spam') @@ -515,6 +565,7 @@ def test_convergence_with_separable_data # Clear separation between classes should converge quickly @classifier.train(spam: ['spam spam spam'] * 20) @classifier.train(ham: ['ham ham ham'] * 20) + @classifier.fit # Should be able to perfectly classify training data probs = @classifier.probabilities('spam spam spam') diff --git a/test/storage/storage_test.rb b/test/storage/storage_test.rb index b72c246..c581617 100644 --- a/test/storage/storage_test.rb +++ b/test/storage/storage_test.rb @@ -668,6 +668,7 @@ def test_reload_restores_saved_state @classifier.storage = storage @classifier.train_spam 'buy now limited offer' @classifier.train_ham 'hello friend meeting' + @classifier.fit @classifier.save original_classification = @classifier.classify('buy now') @@ -709,6 +710,7 @@ def test_load_with_storage @classifier.storage = storage @classifier.train_spam 'buy now limited offer' @classifier.train_ham 'hello friend meeting' + @classifier.fit @classifier.save loaded = Classifier::LogisticRegression.load(storage: storage) @@ -729,10 +731,12 @@ def test_loaded_classifier_can_save_immediately @classifier.storage = storage @classifier.train_spam 'buy now' @classifier.train_ham 'hello friend' + @classifier.fit @classifier.save loaded = Classifier::LogisticRegression.load(storage: storage) loaded.train_spam 'more spam words' + loaded.fit loaded.save reloaded = Classifier::LogisticRegression.load(storage: storage)