`
ay_guobo
  • 浏览: 116036 次
  • 性别: Icon_minigender_1
  • 来自: 札幌
社区版块
存档分类
最新评论

基于 Python 的支持向量机

 
阅读更多

此模块需要与 SVMlight 配合使用：它以外部程序的方式调用 SVMlight 的 svm_learn 和 svm_classify 可执行文件（二者所在目录由构造参数 svm_path 指定）。

 

001 # svmlight.py
002 #
003 # Author: Clint Burfoot <clint@burfoot.info>
004 #
005 
006 """
007 An interface class for U{SVM light<http://svmlight.joachims.org/>}
008 """
009 
010 import os
011 import tempfile
012 import math
013 from subprocess import call
014 
class SVMLight:
    """
    An interface class for U{SVM light<http://svmlight.joachims.org/>}

    This class currently supports classification with default options
    only. It calls the SVMLight binaries (C{svm_learn} and C{svm_classify})
    as external programs, exchanging data with them through files in a
    private temporary directory.

    Future versions should add a SWIG interface and support for use of
    non-default SVMlight options.

    C{SVMLight} reads sparse feature vectors - dictionaries with 0-based
    integer keys, representing features, and arbitrary numeric values.
    Feature C{k} is written to SVMlight as index C{k + 1}.
    """

    learn_binary = "svm_learn"
    classify_binary = "svm_classify"

    def __init__(self, svm_path, labels=None, vectors=None, model=None,
                 cleanup=False):
        """
        Trains a new classifier, or loads a pre-trained one.

        @type svm_path: C{str}
        @param svm_path: The filesystem path to the SVMLight binaries
        @type labels: C{tuple}
        @param labels: A tuple of boolean training set labels.
        @type vectors: C{tuple}
        @param vectors: A tuple of sparse feature vectors.
        @type model: A C{tuple} of C{str}
        @param model: The lines from an SVMlight model file. Specify this
        instead of C{labels} and C{vectors} to use a pre-trained classifier;
        no training is performed in that case.
        @type cleanup: C{bool}
        @param cleanup: If true, the temporary working directory is deleted
        when this object is finalized.
        @raise ValueError: If no C{model} is given and C{labels}/C{vectors}
        are missing or have different lengths.
        @raise RuntimeError: If C{svm_learn} exits with a non-zero status.
        """
        self._svm_learn = os.path.join(svm_path, SVMLight.learn_binary)
        self._svm_classify = os.path.join(svm_path, SVMLight.classify_binary)
        self._cleanup = cleanup
        self._devnull = None
        # Set before anything can fail so __del__ can tell whether a
        # working directory was ever created.
        self._directory = None

        # Validate before creating the temp directory so a bad call does
        # not leave an orphaned directory behind.
        if model is None:
            if labels is None or vectors is None:
                raise ValueError(
                    "labels and vectors are required when no model is given")
            if len(labels) != len(vectors):
                raise ValueError(
                    "labels and vectors arrays are different lengths")

        self._create_workspace()

        if model is not None:
            # Pre-trained classifier: only materialize the model file.
            self._write_model(self._model_fname, model)
            self.model = model
        else:
            self._write_vectors(self._example_fname, labels, vectors)
            self._run((self._svm_learn, self._example_fname,
                       self._model_fname))
            self.model = self._read_model()

    def _create_workspace(self):
        # Create a fresh temporary directory and the file names used to
        # exchange data with the SVMlight binaries.
        self._directory = tempfile.mkdtemp()
        self._example_fname = os.path.join(self._directory, "example")
        self._model_fname = os.path.join(self._directory, "model")
        self._input_fname = os.path.join(self._directory, "input")
        self._output_fname = os.path.join(self._directory, "output")

    def _run(self, args):
        # Run an external SVMlight binary with stdout discarded. An explicit
        # check is used instead of assert, which is stripped under python -O.
        ret = call(args, stdout=self.devnull)
        if ret != 0:
            raise RuntimeError("%s exited with status %d" % (args[0], ret))

    def _get_devnull(self):
        # Return a cached writable handle to the null device. os.devnull
        # names the correct device on both POSIX and Windows.
        if self._devnull is None:
            self._devnull = open(os.devnull, "w")
        return self._devnull
    devnull = property(_get_devnull)

    def __getstate__(self):
        # Open file handles cannot be pickled; the null device is reopened
        # lazily on first use after unpickling.
        state = self.__dict__.copy()
        state['_devnull'] = None
        return state

    def __setstate__(self, state):
        # The pickled working directory may belong to another process or may
        # already have been removed by cleanup. Sharing it would also let
        # either copy delete the other's files. So give the restored object
        # its own directory and rewrite the model file from self.model.
        self.__dict__.update(state)
        self._devnull = None
        self._create_workspace()
        self._write_model(self._model_fname, self.model)

    def classify(self, vectors):
        """
        Classify feature vectors.

        @type vectors: C{tuple}
        @param vectors: A tuple of sparse binary feature vectors.
        @rtype: C{list}
        @return: A list of C{float} vector classifications, one per vector
        (empty for empty input, without invoking C{svm_classify}).
        @raise RuntimeError: If C{svm_classify} fails or returns a different
        number of results than vectors given.
        """
        if len(vectors) == 0:
            return []
        # Placeholder labels; only the values read back from the output
        # file are returned to the caller.
        self._write_vectors(self._input_fname, [True] * len(vectors), vectors)
        self._run((self._svm_classify, self._input_fname, self._model_fname,
                   self._output_fname))
        results = self._read_classification()
        if len(results) != len(vectors):
            raise RuntimeError("expected %d classifications, got %d"
                               % (len(vectors), len(results)))
        return results

    def _write_vectors(self, fname, labels, vectors):
        # Write one "<label> <index>:<value> ..." line per vector: label "1"
        # for truthy labels and "-1" otherwise. Feature indices are written
        # 1-based, in increasing order (the SVMlight input format requires
        # ordered feature numbers).
        if len(labels) != len(vectors):
            raise ValueError("labels and vectors arrays are different lengths")
        with open(fname, "w") as out:
            for label, vector in zip(labels, vectors):
                feature_strings = ["%d:%s" % (feature + 1, str(vector[feature]))
                                   for feature in sorted(vector)]
                out.write("%s %s\n" % ("1" if label else "-1",
                                       " ".join(feature_strings)))

    def _write_model(self, fname, model):
        # Write the model lines to fname, one per line.
        with open(fname, "w") as out:
            for line in model:
                out.write("%s\n" % line)

    def _read_classification(self):
        # Read the svm_classify output file: one float per non-blank line.
        with open(self._output_fname, "r") as f:
            result = [float(line) for line in f if line.strip()]
        if not result:
            raise RuntimeError("no classifications read from %s"
                               % self._output_fname)
        return result

    def _read_model(self):
        # Read the model file written by svm_learn, stripping trailing
        # whitespace from each line.
        with open(self._model_fname, "r") as f:
            result = [line.rstrip() for line in f]
        if not result:
            raise RuntimeError("empty model file %s" % self._model_fname)
        return result

    def __del__(self):
        # Close the cached null-device handle and, if cleanup was requested,
        # remove the working directory. Attributes are read defensively
        # because __del__ also runs when __init__ raised part-way. The
        # method is safe to call more than once.
        devnull = getattr(self, "_devnull", None)
        if devnull is not None:
            devnull.close()
            self._devnull = None
        directory = getattr(self, "_directory", None)
        if (getattr(self, "_cleanup", False) and directory
                and os.path.isdir(directory)):
            for fname in os.listdir(directory):
                os.unlink(os.path.join(directory, fname))
            os.rmdir(directory)

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics