From ca6a5055ad8a9ea82e9f42aff3c906903d5e6df7 Mon Sep 17 00:00:00 2001
From: Victor Wagner <vitus@wagner.pp.ru>
Date: Sun, 21 Dec 2014 11:47:56 +0300
Subject: [PATCH] Fixes some style. Improved tests coverage for bio,oid,digest
 and cipher. Prepare digest for inheritance in future mac module

---
 ctypescrypto/bio.py    |  2 +-
 ctypescrypto/cipher.py |  2 +-
 ctypescrypto/digest.py | 30 +++++++++++++++++++++++-------
 ctypescrypto/oid.py    |  6 ++++--
 ctypescrypto/pkey.py   |  4 ++--
 ctypescrypto/rand.py   |  2 +-
 ctypescrypto/x509.py   |  6 ++++--
 tests/testbio.py       | 14 ++++++++++++++
 tests/testdigest.py    | 32 +++++++++++++++++++++++++++++++-
 tests/testoids.py      |  6 ++++++
 tests/testx509.py      | 27 +++++++++++++++++++++++++++
 11 files changed, 114 insertions(+), 17 deletions(-)

diff --git a/ctypescrypto/bio.py b/ctypescrypto/bio.py
index 2743f70..b1700b9 100644
--- a/ctypescrypto/bio.py
+++ b/ctypescrypto/bio.py
@@ -42,7 +42,7 @@ class Membio(object):
 		@param length - if specifed, limits amount of data read. If not BIO is read until end of buffer
 		"""
 		if not length is None:
-			if type(length)!=type(0):
+			if not isinstance(length,(int,long)):
 				raise TypeError("length to read should be number")
 			buf=create_string_buffer(length)
 			readbytes=libcrypto.BIO_read(self.bio,buf,length)
diff --git a/ctypescrypto/cipher.py b/ctypescrypto/cipher.py
index 3322725..815a425 100644
--- a/ctypescrypto/cipher.py
+++ b/ctypescrypto/cipher.py
@@ -171,7 +171,7 @@ class Cipher:
 		"""
 		if self.cipher_finalized :
 			raise CipherError("No updates allowed")
-		if type(data) != type(""):
+		if not isinstance(data,str):
 			raise TypeError("A string is expected")
 		if len(data) <= 0:
 			return ""
diff --git a/ctypescrypto/digest.py b/ctypescrypto/digest.py
index 2963e70..30d77c5 100644
--- a/ctypescrypto/digest.py
+++ b/ctypescrypto/digest.py
@@ -41,13 +41,29 @@ class DigestType(object):
 	"""
 	def __init__(self,	digest_name):
 		"""
-			Finds digest by its name
+			Finds digest by its name. You can pass Oid object instead of
+			name.
+
+			Special case is when None is passed as name. In this case
+			unitialized digest is created, and can be initalized later
+			by setting its digest attribute to pointer to EVP_MD
 		"""
-		self.digest_name = digest_name
-		self.digest = libcrypto.EVP_get_digestbyname(self.digest_name)
+		if digest_name is None:
+			return 
+		if isinstance(digest_name,Oid):
+			self.digest_name=digest_name.longname()
+			self.digest=libcrypto.EVP_get_digestbyname(self.digest_name)
+		else:
+			self.digest_name = str(digest_name)
+			self.digest = libcrypto.EVP_get_digestbyname(self.digest_name)
 		if self.digest is None:
 			raise DigestError("Unknown digest: %s" % self.digest_name)
 
+	@property
+	def name(self):
+		if not hasattr(self,'digest_name'):
+			self.digest_name=Oid(libcrypto.EVP_MD_type(self.digest)).longname()
+		return self.digest_name
 	def __del__(self):
 		pass
 	def digest_size(self):
@@ -69,7 +85,7 @@ class Digest(object):
 		"""
 		self._clean_ctx()
 		self.ctx = libcrypto.EVP_MD_CTX_create()
-		if self.ctx == 0:
+		if self.ctx is None:
 			raise DigestError("Unable to create digest context")
 		result = libcrypto.EVP_DigestInit_ex(self.ctx, digest_type.digest, None)
 		if result == 0:
@@ -92,11 +108,11 @@ class Digest(object):
 		"""
 		if self.digest_finalized:
 			raise DigestError("No updates allowed")
-		if type(data) != type(""):
+		if not isinstance(data,str):
 			raise TypeError("A string is expected")
 		if length is None:
-			length=len(data)
-		elif length> len(data):
+			length = len(data)
+		elif length > len(data):
 			raise ValueError("Specified length is greater than length of data")
 		result = libcrypto.EVP_DigestUpdate(self.ctx, c_char_p(data), length)
 		if result != 1:
diff --git a/ctypescrypto/oid.py b/ctypescrypto/oid.py
index 6941ce4..8caa57a 100644
--- a/ctypescrypto/oid.py
+++ b/ctypescrypto/oid.py
@@ -29,11 +29,13 @@ class Oid(object):
 
 	def __init__(self,value):
 		" Object constuctor. Accepts string or integer"
-		if type(value) == type(""):
+		if isinstance(value,unicode):
+			value=value.encode('ascii')
+		if isinstance(value,str):
 			self.nid=libcrypto.OBJ_txt2nid(value)
 			if self.nid==0:
 				raise ValueError("Cannot find object %s in the database"%(value))
-		elif type(value) == type(0):
+		elif isinstance(value,(int,long)):
 			cn=libcrypto.OBJ_nid2sn(value)
 			if cn is None:
 				raise ValueError("No such nid %d in the database"%(value))
diff --git a/ctypescrypto/pkey.py b/ctypescrypto/pkey.py
index 85b831d..fdf384d 100644
--- a/ctypescrypto/pkey.py
+++ b/ctypescrypto/pkey.py
@@ -241,9 +241,9 @@ class PKey(object):
 				continue
 			rv=libcrypto.EVP_PKEY_CTX_ctrl_str(ctx,oper,str(opts[oper]))
 			if rv==-2:
-				raise PKeyError("Parameter %s is not supported by key"%(oper))
+				raise PKeyError("Parameter %s is not supported by key"%(oper,))
 			if rv<1:
-				raise PKeyError("Error setting parameter %s"(oper))
+				raise PKeyError("Error setting parameter %s"%(oper,))
 # Declare function prototypes
 libcrypto.EVP_PKEY_cmp.argtypes=(c_void_p,c_void_p)
 libcrypto.PEM_read_bio_PrivateKey.restype=c_void_p
diff --git a/ctypescrypto/rand.py b/ctypescrypto/rand.py
index 4cc86e2..4bc073a 100644
--- a/ctypescrypto/rand.py
+++ b/ctypescrypto/rand.py
@@ -47,7 +47,7 @@ def seed(data, entropy=None):
 		If entropy is not None, it should be floating point(double)
 		value estimating amount of entropy  in the data (in bytes).
 	"""
-	if type(data) != type(""):
+	if not isinstance(data,str):
 		raise TypeError("A string is expected")
 	ptr = c_char_p(data)
 	size = len(data)
diff --git a/ctypescrypto/x509.py b/ctypescrypto/x509.py
index 44919f1..4f08632 100644
--- a/ctypescrypto/x509.py
+++ b/ctypescrypto/x509.py
@@ -148,7 +148,7 @@ class X509Name(object):
 			# Return first matching field
 			idx=libcrypto.X509_NAME_get_index_by_NID(self.ptr,key.nid,-1)
 			if idx<0:
-				raise KeyError("Key not found "+repr(Oid))
+				raise KeyError("Key not found "+str(Oid))
 			entry=libcrypto.X509_NAME_get_entry(self.ptr,idx)
 			s=libcrypto.X509_NAME_ENTRY_get_data(entry)
 			b=Membio()
@@ -168,6 +168,8 @@ class X509Name(object):
 	def __setitem__(self,key,val):
 		if not self.writable:
 			raise ValueError("Attempt to modify constant X509 object")
+		else:
+			raise NotImplementedError
 
 class _x509_ext(Structure):
 	""" Represens C structure X509_EXTENSION """
@@ -191,7 +193,6 @@ class X509_EXT(object):
 	def __str__(self):
 		b=Membio()
 		libcrypto.X509V3_EXT_print(b.bio,self.ptr,0x20010,0)
-		libcrypto.X509V3_EXT_print.argtypes=(c_void_p,POINTER(_x509_ext),c_long,c_int)
 		return str(b)
 	def __unicode__(self):
 		b=Membio()
@@ -547,3 +548,4 @@ libcrypto.X509_EXTENSION_dup.restype=POINTER(_x509_ext)
 libcrypto.X509V3_EXT_print.argtypes=(c_void_p,POINTER(_x509_ext),c_long,c_int)
 libcrypto.X509_get_ext.restype=c_void_p
 libcrypto.X509_get_ext.argtypes=(c_void_p,c_int)
+libcrypto.X509V3_EXT_print.argtypes=(c_void_p,POINTER(_x509_ext),c_long,c_int)
diff --git a/tests/testbio.py b/tests/testbio.py
index 1bd2e1f..c8fc313 100644
--- a/tests/testbio.py
+++ b/tests/testbio.py
@@ -6,8 +6,22 @@ class TestRead(unittest.TestCase):
 		s="A quick brown fox jumps over a lazy dog"
 		bio=Membio(s)
 		data=bio.read()
+		self.assertEqual(data,s)
+		data2=bio.read()
+		self.assertEqual(data2,"")
 		del bio
+	def test_readwithlen(self):
+		s="A quick brown fox jumps over a lazy dog"
+		bio=Membio(s)
+		data=bio.read(len(s))
 		self.assertEqual(data,s)
+		data2=bio.read(5)
+		self.assertEqual(data2,"")
+	def test_readwrongtype(self):
+		s="A quick brown fox jumps over a lazy dog"
+		bio=Membio(s)
+		with self.assertRaises(TypeError):
+			data=bio.read("5")
 	def test_reset(self):
 		s="A quick brown fox jumps over a lazy dog"
 		bio=Membio(s)
diff --git a/tests/testdigest.py b/tests/testdigest.py
index e840080..741c1d6 100644
--- a/tests/testdigest.py
+++ b/tests/testdigest.py
@@ -9,32 +9,58 @@ class TestDigestType(unittest.TestCase):
 		self.assertEqual(d.digest_size(),16)
 		self.assertEqual(d.block_size(),64)
 		self.assertEqual(d.oid(),Oid("md4"))
+		self.assertEqual(d.name,'md4')
 	def test_md5(self):
 		d=digest.DigestType("md5")
 		self.assertEqual(d.digest_size(),16)
 		self.assertEqual(d.block_size(),64)
 		self.assertEqual(d.oid(),Oid("md5"))
+		self.assertEqual(d.name,'md5')
 	def test_sha1(self):
 		d=digest.DigestType("sha1")
 		self.assertEqual(d.digest_size(),20)
 		self.assertEqual(d.block_size(),64)
 		self.assertEqual(d.oid(),Oid("sha1"))
+		self.assertEqual(d.name,'sha1')
 	def test_sha256(self):
 		d=digest.DigestType("sha256")
 		self.assertEqual(d.digest_size(),32)
 		self.assertEqual(d.block_size(),64)
 		self.assertEqual(d.oid(),Oid("sha256"))
+		self.assertEqual(d.name,'sha256')
 	def test_sha384(self):
 		d=digest.DigestType("sha384")
 		self.assertEqual(d.digest_size(),48)
 		self.assertEqual(d.block_size(),128)
 		self.assertEqual(d.oid(),Oid("sha384"))
+		self.assertEqual(d.name,'sha384')
 	def test_sha512(self):
 		d=digest.DigestType("sha512")
 		self.assertEqual(d.digest_size(),64)
 		self.assertEqual(d.block_size(),128)
 		self.assertEqual(d.oid(),Oid("sha512"))
-		
+		self.assertEqual(d.name,'sha512')
+	def test_createfromoid(self):
+		oid=Oid('sha256')
+		d=digest.DigestType(oid)
+		self.assertEqual(d.digest_size(),32)
+		self.assertEqual(d.block_size(),64)
+		self.assertEqual(d.oid(),Oid("sha256"))
+		self.assertEqual(d.name,'sha256')
+	def test_createfromEVP_MD(self):
+		d1=digest.DigestType("sha256")
+		d2=digest.DigestType(None)
+		with self.assertRaises(AttributeError):
+			s=d2.name
+		d2.digest=d1.digest
+		self.assertEqual(d2.digest_size(),32)
+		self.assertEqual(d2.block_size(),64)
+		self.assertEqual(d2.oid(),Oid("sha256"))
+		self.assertEqual(d2.name,'sha256')
+	def test_invalidDigest(self):
+		with self.assertRaises(digest.DigestError):
+			d=digest.DigestType("no-such-digest")
+
 
 class TestIface(unittest.TestCase):
 	""" Test all methods with one algorithms """
@@ -46,6 +72,10 @@ class TestIface(unittest.TestCase):
 		dgst.update(self.msg)
 		self.assertEqual(dgst.digest_size,20)
 		self.assertEqual(dgst.hexdigest(),self.dgst)
+	def test_digestwithdata(self):
+		md=digest.DigestType("sha1")
+		dgst=digest.Digest(md)
+		self.assertEqual(dgst.digest(self.msg),b16decode(self.dgst))
 	def test_length(self):
 		l=len(self.msg)
 		msg=self.msg+" Dog barks furiously."
diff --git a/tests/testoids.py b/tests/testoids.py
index f324f7e..0ff2564 100644
--- a/tests/testoids.py
+++ b/tests/testoids.py
@@ -24,12 +24,18 @@ class TestStandard(unittest.TestCase):
 		o=Oid("2.5.4.3")
 		x=Oid(o.nid)
 		self.assertEqual(o.nid,x.nid)
+	def test_fromunicode(self):
+		o=Oid(u'commonName')
+		self.assertEqual(o.shortname(),'CN')
 	def test_wrongoid(self):
 		with self.assertRaises(ValueError):
 			o=Oid("1.2.3.4.5.6.7.8.10.111.1111")
 	def test_wrongname(self):
 		with self.assertRaises(ValueError):
 			o=Oid("No such oid in the database")
+	def test_wrongnid(self):
+		with self.assertRaises(ValueError):
+			o=Oid(9999999)
 	def test_wrongtype(self):
 		with self.assertRaises(TypeError):
 			o=Oid([2,5,3,4])
diff --git a/tests/testx509.py b/tests/testx509.py
index a54417b..82b6d55 100644
--- a/tests/testx509.py
+++ b/tests/testx509.py
@@ -118,6 +118,12 @@ zVMSW4SOwg/H7ZMZ2cn6j1g0djIvruFQFGHUqFijyDATI+/GJYw2jxyA
 	def test_subject(self):
 		c=X509(self.cert1)
 		self.assertEqual(unicode(c.subject),u'C=RU,ST=Москва,L=Москва,O=Частное лицо,CN=Виктор Вагнер')
+	def test_subject_str(self):
+		c=X509(self.cert1)
+		self.assertEqual(str(c.subject),b'C=RU,ST=\\D0\\9C\\D0\\BE\\D1\\81\\D0\\BA\\D0\\B2\\D0\\B0,L=\\D0\\9C\\D0\\BE\\D1\\81\\D0\\BA\\D0\\B2\\D0\\B0,O=\\D0\\A7\\D0\\B0\\D1\\81\\D1\\82\\D0\\BD\\D0\\BE\\D0\\B5 \\D0\\BB\\D0\\B8\\D1\\86\\D0\\BE,CN=\\D0\\92\\D0\\B8\\D0\\BA\\D1\\82\\D0\\BE\\D1\\80 \\D0\\92\\D0\\B0\\D0\\B3\\D0\\BD\\D0\\B5\\D1\\80')
+	def test_subject_len(self):
+		c=X509(self.cert1)
+		self.assertEqual(len(c.subject),5)
 	def test_issuer(self):
 		c=X509(self.cert1)
 		self.assertEqual(unicode(c.issuer),u'C=RU,ST=Москва,O=Удостоверяющий центр,CN=Виктор Вагнер,emailAddress=vitus@wagner.pp.ru')
@@ -125,6 +131,19 @@ zVMSW4SOwg/H7ZMZ2cn6j1g0djIvruFQFGHUqFijyDATI+/GJYw2jxyA
 		c=X509(self.cert1)
 		self.assertEqual(c.subject[Oid("C")],"RU")
 		self.assertEqual(c.subject[Oid("L")],u'\u041c\u043e\u0441\u043a\u0432\u0430')
+	def test_subjectbadsubfield(self):
+		c=X509(self.cert1)
+		with self.assertRaises(KeyError):
+			x=c.subject[Oid("streetAddress")]
+	def test_subjectfieldindex(self):
+		c=X509(self.cert1)
+		self.assertEqual(repr(c.subject[0]),repr((Oid('C'),u'RU')))
+	def test_subjectbadindex(self):
+		c=X509(self.cert1)
+		with self.assertRaises(IndexError):
+			x=c.subject[11]
+		with self.assertRaises(IndexError):
+			x=c.subject[-1]
 	def test_notBefore(self):
 		c=X509(self.cert1)
 		self.assertEqual(c.startDate,datetime.datetime(2014,10,26,19,07,17,0,utc))
@@ -169,11 +188,16 @@ zVMSW4SOwg/H7ZMZ2cn6j1g0djIvruFQFGHUqFijyDATI+/GJYw2jxyA
 		cert=X509(self.cert1)
 		ext=cert.extensions[0]
 		self.assertEqual(str(ext),'CA:FALSE')
+		self.assertEqual(unicode(ext),u'CA:FALSE')
 	def test_extenson_find(self):
 		cert=X509(self.cert1)
 		exts=cert.extensions.find(Oid('subjectAltName'))
 		self.assertEqual(len(exts),1)
 		self.assertEqual(exts[0].oid,Oid('subjectAltName'))
+	def test_extension_bad_find(self):
+		cert=X509(self.cert1)
+		with self.assertRaises(TypeError):
+			exts=cert.extensions.find('subjectAltName')
 	def test_extenson_critical(self):
 		cert=X509(self.digicert_cert)
 		crit_exts=cert.extensions.find_critical()
@@ -190,6 +214,9 @@ zVMSW4SOwg/H7ZMZ2cn6j1g0djIvruFQFGHUqFijyDATI+/GJYw2jxyA
 		pk2=c.pubkey
 		self.assertFalse(c.verify(key=pk2))
 		self.assertTrue(c.verify(key=pubkey))
+	def test_verify_self_singed(self):
+		ca=X509(self.ca_cert)
+		self.assertTrue(ca.verify())
 	def test_default_filestore(self):
 		store=X509Store(default=True)
 		c1=X509(self.cert1)
-- 
2.39.5